1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 #ifndef HEYOKA_TAYLOR_HPP
10 #define HEYOKA_TAYLOR_HPP
11
12 #include <heyoka/config.hpp>
13
14 #include <cstddef>
15 #include <cstdint>
16 #include <functional>
17 #include <limits>
18 #include <memory>
19 #include <optional>
20 #include <ostream>
21 #include <stdexcept>
22 #include <string>
23 #include <tuple>
24 #include <type_traits>
25 #include <typeinfo>
26 #include <utility>
27 #include <variant>
28 #include <vector>
29
30 #if defined(HEYOKA_HAVE_REAL128)
31
32 #include <mp++/real128.hpp>
33
34 #endif
35
36 #include <heyoka/callable.hpp>
37 #include <heyoka/detail/dfloat.hpp>
38 #include <heyoka/detail/fwd_decl.hpp>
39 #include <heyoka/detail/igor.hpp>
40 #include <heyoka/detail/llvm_fwd.hpp>
41 #include <heyoka/detail/llvm_helpers.hpp>
42 #include <heyoka/detail/type_traits.hpp>
43 #include <heyoka/detail/visibility.hpp>
44 #include <heyoka/expression.hpp>
45 #include <heyoka/kw.hpp>
46 #include <heyoka/llvm_state.hpp>
47 #include <heyoka/number.hpp>
48 #include <heyoka/param.hpp>
49 #include <heyoka/s11n.hpp>
50 #include <heyoka/variable.hpp>
51
52 namespace heyoka
53 {
54
55 namespace detail
56 {
57
58 // NOTE: these are various utilities useful when dealing in a generic
59 // fashion with numbers/params in Taylor functions.
60
61 // Helper to detect if T is a number or a param.
62 template <typename T>
63 using is_num_param = std::disjunction<std::is_same<T, number>, std::is_same<T, param>>;
64
65 template <typename T>
66 inline constexpr bool is_num_param_v = is_num_param<T>::value;
67
68 HEYOKA_DLL_PUBLIC llvm::Value *taylor_codegen_numparam_dbl(llvm_state &, const number &, llvm::Value *, std::uint32_t);
69 HEYOKA_DLL_PUBLIC llvm::Value *taylor_codegen_numparam_ldbl(llvm_state &, const number &, llvm::Value *, std::uint32_t);
70
71 #if defined(HEYOKA_HAVE_REAL128)
72
73 HEYOKA_DLL_PUBLIC llvm::Value *taylor_codegen_numparam_f128(llvm_state &, const number &, llvm::Value *, std::uint32_t);
74
75 #endif
76
77 HEYOKA_DLL_PUBLIC llvm::Value *taylor_codegen_numparam_dbl(llvm_state &, const param &, llvm::Value *, std::uint32_t);
78 HEYOKA_DLL_PUBLIC llvm::Value *taylor_codegen_numparam_ldbl(llvm_state &, const param &, llvm::Value *, std::uint32_t);
79
80 #if defined(HEYOKA_HAVE_REAL128)
81
82 HEYOKA_DLL_PUBLIC llvm::Value *taylor_codegen_numparam_f128(llvm_state &, const param &, llvm::Value *, std::uint32_t);
83
84 #endif
85
86 template <typename T, typename U>
taylor_codegen_numparam(llvm_state & s,const U & n,llvm::Value * par_ptr,std::uint32_t batch_size)87 llvm::Value *taylor_codegen_numparam(llvm_state &s, const U &n, llvm::Value *par_ptr, std::uint32_t batch_size)
88 {
89 if constexpr (std::is_same_v<T, double>) {
90 return taylor_codegen_numparam_dbl(s, n, par_ptr, batch_size);
91 } else if constexpr (std::is_same_v<T, long double>) {
92 return taylor_codegen_numparam_ldbl(s, n, par_ptr, batch_size);
93 #if defined(HEYOKA_HAVE_REAL128)
94 } else if constexpr (std::is_same_v<T, mppp::real128>) {
95 return taylor_codegen_numparam_f128(s, n, par_ptr, batch_size);
96 #endif
97 } else {
98 static_assert(detail::always_false_v<T>, "Unhandled type.");
99 }
100 }
101
102 HEYOKA_DLL_PUBLIC llvm::Value *taylor_c_diff_numparam_codegen(llvm_state &, const number &, llvm::Value *,
103 llvm::Value *, std::uint32_t);
104 HEYOKA_DLL_PUBLIC llvm::Value *taylor_c_diff_numparam_codegen(llvm_state &, const param &, llvm::Value *, llvm::Value *,
105 std::uint32_t);
106
107 HEYOKA_DLL_PUBLIC llvm::Value *taylor_fetch_diff(const std::vector<llvm::Value *> &, std::uint32_t, std::uint32_t,
108 std::uint32_t);
109
110 HEYOKA_DLL_PUBLIC llvm::Value *taylor_c_load_diff(llvm_state &, llvm::Value *, std::uint32_t, llvm::Value *,
111 llvm::Value *);
112
113 HEYOKA_DLL_PUBLIC std::pair<std::string, std::vector<llvm::Type *>>
114 taylor_c_diff_func_name_args_impl(llvm::LLVMContext &, const std::string &, llvm::Type *, std::uint32_t,
115 const std::vector<std::variant<variable, number, param>> &, std::uint32_t);
116
117 // NOTE: this function will return a pair containing:
118 //
119 // - the mangled name and
120 // - the list of LLVM argument types
121 //
122 // for the function implementing the Taylor derivative in compact mode of the mathematical function
123 // called "name". The mangled name is assembled from "name", the types of the arguments args, the number
124 // of uvars and the scalar or vector floating-point type in use (which depends on T and batch_size).
125 template <typename T>
126 inline std::pair<std::string, std::vector<llvm::Type *>>
taylor_c_diff_func_name_args(llvm::LLVMContext & c,const std::string & name,std::uint32_t n_uvars,std::uint32_t batch_size,const std::vector<std::variant<variable,number,param>> & args,std::uint32_t n_hidden_deps=0)127 taylor_c_diff_func_name_args(llvm::LLVMContext &c, const std::string &name, std::uint32_t n_uvars,
128 std::uint32_t batch_size, const std::vector<std::variant<variable, number, param>> &args,
129 std::uint32_t n_hidden_deps = 0)
130 {
131 // Fetch the floating-point type.
132 auto val_t = to_llvm_vector_type<T>(c, batch_size);
133
134 return taylor_c_diff_func_name_args_impl(c, name, val_t, n_uvars, args, n_hidden_deps);
135 }
136
137 // Add a function for computing the dense output
138 // via polynomial evaluation.
139 template <typename T>
140 void taylor_add_d_out_function(llvm_state &, std::uint32_t, std::uint32_t, std::uint32_t, bool, bool = true,
141 bool = true);
142
143 } // namespace detail
144
145 HEYOKA_DLL_PUBLIC std::pair<taylor_dc_t, std::vector<std::uint32_t>> taylor_decompose(const std::vector<expression> &,
146 const std::vector<expression> &);
147 HEYOKA_DLL_PUBLIC std::pair<taylor_dc_t, std::vector<std::uint32_t>>
148 taylor_decompose(const std::vector<std::pair<expression, expression>> &, const std::vector<expression> &);
149
150 HEYOKA_DLL_PUBLIC taylor_dc_t taylor_add_jet_dbl(llvm_state &, const std::string &, const std::vector<expression> &,
151 std::uint32_t, std::uint32_t, bool, bool,
152 const std::vector<expression> &);
153 HEYOKA_DLL_PUBLIC taylor_dc_t taylor_add_jet_ldbl(llvm_state &, const std::string &, const std::vector<expression> &,
154 std::uint32_t, std::uint32_t, bool, bool,
155 const std::vector<expression> &);
156
157 #if defined(HEYOKA_HAVE_REAL128)
158
159 HEYOKA_DLL_PUBLIC taylor_dc_t taylor_add_jet_f128(llvm_state &, const std::string &, const std::vector<expression> &,
160 std::uint32_t, std::uint32_t, bool, bool,
161 const std::vector<expression> &);
162
163 #endif
164
165 template <typename T>
taylor_add_jet(llvm_state & s,const std::string & name,const std::vector<expression> & sys,std::uint32_t order,std::uint32_t batch_size,bool high_accuracy,bool compact_mode,const std::vector<expression> & sv_funcs={})166 taylor_dc_t taylor_add_jet(llvm_state &s, const std::string &name, const std::vector<expression> &sys,
167 std::uint32_t order, std::uint32_t batch_size, bool high_accuracy, bool compact_mode,
168 const std::vector<expression> &sv_funcs = {})
169 {
170 if constexpr (std::is_same_v<T, double>) {
171 return taylor_add_jet_dbl(s, name, sys, order, batch_size, high_accuracy, compact_mode, sv_funcs);
172 } else if constexpr (std::is_same_v<T, long double>) {
173 return taylor_add_jet_ldbl(s, name, sys, order, batch_size, high_accuracy, compact_mode, sv_funcs);
174 #if defined(HEYOKA_HAVE_REAL128)
175 } else if constexpr (std::is_same_v<T, mppp::real128>) {
176 return taylor_add_jet_f128(s, name, sys, order, batch_size, high_accuracy, compact_mode, sv_funcs);
177 #endif
178 } else {
179 static_assert(detail::always_false_v<T>, "Unhandled type.");
180 }
181 }
182
183 HEYOKA_DLL_PUBLIC taylor_dc_t taylor_add_jet_dbl(llvm_state &, const std::string &,
184 const std::vector<std::pair<expression, expression>> &, std::uint32_t,
185 std::uint32_t, bool, bool, const std::vector<expression> &);
186 HEYOKA_DLL_PUBLIC taylor_dc_t taylor_add_jet_ldbl(llvm_state &, const std::string &,
187 const std::vector<std::pair<expression, expression>> &, std::uint32_t,
188 std::uint32_t, bool, bool, const std::vector<expression> &);
189
190 #if defined(HEYOKA_HAVE_REAL128)
191
192 HEYOKA_DLL_PUBLIC taylor_dc_t taylor_add_jet_f128(llvm_state &, const std::string &,
193 const std::vector<std::pair<expression, expression>> &, std::uint32_t,
194 std::uint32_t, bool, bool, const std::vector<expression> &);
195
196 #endif
197
198 template <typename T>
taylor_add_jet(llvm_state & s,const std::string & name,const std::vector<std::pair<expression,expression>> & sys,std::uint32_t order,std::uint32_t batch_size,bool high_accuracy,bool compact_mode,const std::vector<expression> & sv_funcs={})199 taylor_dc_t taylor_add_jet(llvm_state &s, const std::string &name,
200 const std::vector<std::pair<expression, expression>> &sys, std::uint32_t order,
201 std::uint32_t batch_size, bool high_accuracy, bool compact_mode,
202 const std::vector<expression> &sv_funcs = {})
203 {
204 if constexpr (std::is_same_v<T, double>) {
205 return taylor_add_jet_dbl(s, name, sys, order, batch_size, high_accuracy, compact_mode, sv_funcs);
206 } else if constexpr (std::is_same_v<T, long double>) {
207 return taylor_add_jet_ldbl(s, name, sys, order, batch_size, high_accuracy, compact_mode, sv_funcs);
208 #if defined(HEYOKA_HAVE_REAL128)
209 } else if constexpr (std::is_same_v<T, mppp::real128>) {
210 return taylor_add_jet_f128(s, name, sys, order, batch_size, high_accuracy, compact_mode, sv_funcs);
211 #endif
212 } else {
213 static_assert(detail::always_false_v<T>, "Unhandled type.");
214 }
215 }
216
// Outcome of a stepping/propagate function.
//
// NOTE: the named enumerators start at -2**32 - 1 and go downwards.
// This leaves the 2**32 values in the [-2**32, -1] range free for
// signalling stopping terminal events.
enum class taylor_outcome : std::int64_t {
    // Integration step was successful, no time/step limits were reached.
    success = -4294967297ll,
    // Maximum number of steps reached.
    step_limit = -4294967298ll,
    // Time limit reached.
    time_limit = -4294967299ll,
    // Non-finite state detected at the end of the timestep.
    err_nf_state = -4294967300ll,
    // Propagation stopped by callback.
    cb_stop = -4294967301ll
};
228
229 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, taylor_outcome);
230
231 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, event_direction);
232
233 } // namespace heyoka
234
235 // NOTE: implement a workaround for the serialisation of tuples whose first element
236 // is a taylor outcome. We need this because Boost.Serialization treats all enums
// as ints, which is not ok for taylor_outcome (whose underlying type will not
238 // be an int on most platforms). Because it is not possible to override Boost's
239 // enum implementation, we override the serialisation of tuples with outcomes
240 // as first elements, which is all we need in the serialisation of the batch
241 // integrator. The implementation below will be preferred over the generic tuple
242 // s11n because it is more specialised.
243 // NOTE: this workaround is not necessary for the other enums in heyoka because
244 // those all have ints as underlying type.
245 namespace boost
246 {
247
248 namespace serialization
249 {
250
251 template <typename Archive, typename... Args>
save(Archive & ar,const std::tuple<heyoka::taylor_outcome,Args...> & tup,unsigned)252 inline void save(Archive &ar, const std::tuple<heyoka::taylor_outcome, Args...> &tup, unsigned)
253 {
254 auto tf = [&ar](const auto &x) {
255 if constexpr (std::is_same_v<decltype(x), const heyoka::taylor_outcome &>) {
256 ar << static_cast<std::underlying_type_t<heyoka::taylor_outcome>>(x);
257 } else {
258 ar << x;
259 }
260 };
261
262 std::apply([&tf](const auto &...x) { (tf(x), ...); }, tup);
263 }
264
265 template <typename Archive, typename... Args>
load(Archive & ar,std::tuple<heyoka::taylor_outcome,Args...> & tup,unsigned)266 inline void load(Archive &ar, std::tuple<heyoka::taylor_outcome, Args...> &tup, unsigned)
267 {
268 auto tf = [&ar](auto &x) {
269 if constexpr (std::is_same_v<decltype(x), heyoka::taylor_outcome &>) {
270 std::underlying_type_t<heyoka::taylor_outcome> val{};
271 ar >> val;
272
273 x = static_cast<heyoka::taylor_outcome>(val);
274 } else {
275 ar >> x;
276 }
277 };
278
279 std::apply([&tf](auto &...x) { (tf(x), ...); }, tup);
280 }
281
282 template <typename Archive, typename... Args>
serialize(Archive & ar,std::tuple<heyoka::taylor_outcome,Args...> & tup,unsigned v)283 inline void serialize(Archive &ar, std::tuple<heyoka::taylor_outcome, Args...> &tup, unsigned v)
284 {
285 split_free(ar, tup, v);
286 }
287
288 } // namespace serialization
289
290 } // namespace boost
291
292 namespace heyoka
293 {
294
295 namespace kw
296 {
297
298 IGOR_MAKE_NAMED_ARGUMENT(tol);
299 IGOR_MAKE_NAMED_ARGUMENT(high_accuracy);
300 IGOR_MAKE_NAMED_ARGUMENT(compact_mode);
301 IGOR_MAKE_NAMED_ARGUMENT(pars);
302 IGOR_MAKE_NAMED_ARGUMENT(t_events);
303 IGOR_MAKE_NAMED_ARGUMENT(nt_events);
304
305 // NOTE: these are used for constructing events.
306 IGOR_MAKE_NAMED_ARGUMENT(callback);
307 IGOR_MAKE_NAMED_ARGUMENT(cooldown);
308 IGOR_MAKE_NAMED_ARGUMENT(direction);
309
310 // NOTE: these are used in the
311 // propagate_*() functions.
312 IGOR_MAKE_NAMED_ARGUMENT(max_steps);
313 IGOR_MAKE_NAMED_ARGUMENT(max_delta_t);
314 IGOR_MAKE_NAMED_ARGUMENT(write_tc);
315 IGOR_MAKE_NAMED_ARGUMENT(c_output);
316
317 } // namespace kw
318
319 namespace detail
320 {
321
322 // Helper for parsing common options for the Taylor integrators.
323 template <typename T, typename... KwArgs>
taylor_adaptive_common_ops(KwArgs &&...kw_args)324 inline auto taylor_adaptive_common_ops(KwArgs &&...kw_args)
325 {
326 igor::parser p{kw_args...};
327
328 // High accuracy mode (defaults to false).
329 auto high_accuracy = [&p]() -> bool {
330 if constexpr (p.has(kw::high_accuracy)) {
331 return std::forward<decltype(p(kw::high_accuracy))>(p(kw::high_accuracy));
332 } else {
333 return false;
334 }
335 }();
336
337 // tol (defaults to eps).
338 auto tol = [&p]() -> T {
339 if constexpr (p.has(kw::tol)) {
340 auto retval = std::forward<decltype(p(kw::tol))>(p(kw::tol));
341 if (retval != T(0)) {
342 // NOTE: this covers the NaN case as well.
343 return retval;
344 }
345 // NOTE: zero tolerance will be interpreted
346 // as automatically-deduced by falling through
347 // the code below.
348 }
349
350 return std::numeric_limits<T>::epsilon();
351 }();
352
353 // Compact mode (defaults to false).
354 auto compact_mode = [&p]() -> bool {
355 if constexpr (p.has(kw::compact_mode)) {
356 return std::forward<decltype(p(kw::compact_mode))>(p(kw::compact_mode));
357 } else {
358 return false;
359 }
360 }();
361
362 // Vector of parameters (defaults to empty vector).
363 auto pars = [&p]() -> std::vector<T> {
364 if constexpr (p.has(kw::pars)) {
365 return std::forward<decltype(p(kw::pars))>(p(kw::pars));
366 } else {
367 return {};
368 }
369 }();
370
371 return std::tuple{high_accuracy, tol, compact_mode, std::move(pars)};
372 }
373
374 template <typename T, bool B>
375 class HEYOKA_DLL_PUBLIC nt_event_impl
376 {
377 static_assert(is_supported_fp_v<T>, "Unhandled type.");
378
379 public:
380 using callback_t = callable<std::conditional_t<B, void(taylor_adaptive_batch_impl<T> &, T, int, std::uint32_t),
381 void(taylor_adaptive_impl<T> &, T, int)>>;
382
383 private:
384 expression eq;
385 callback_t callback;
386 event_direction dir;
387
388 // Serialization.
389 friend class boost::serialization::access;
390 template <typename Archive>
serialize(Archive & ar,unsigned)391 void serialize(Archive &ar, unsigned)
392 {
393 ar &eq;
394 ar &callback;
395 ar &dir;
396 }
397
398 void finalise_ctor(event_direction);
399
400 public:
401 nt_event_impl();
402
403 template <typename... KwArgs>
nt_event_impl(const expression & e,callback_t cb,KwArgs &&...kw_args)404 explicit nt_event_impl(const expression &e, callback_t cb, KwArgs &&...kw_args)
405 : eq(copy(e)), callback(std::move(cb))
406 {
407 igor::parser p{kw_args...};
408
409 if constexpr (p.has_unnamed_arguments()) {
410 static_assert(detail::always_false_v<KwArgs...>,
411 "The variadic arguments in the construction of a non-terminal event contain "
412 "unnamed arguments.");
413 throw;
414 } else {
415 // Direction (defaults to any).
416 auto d = [&p]() -> event_direction {
417 if constexpr (p.has(kw::direction)) {
418 return std::forward<decltype(p(kw::direction))>(p(kw::direction));
419 } else {
420 return event_direction::any;
421 }
422 }();
423
424 finalise_ctor(d);
425 }
426 }
427
428 nt_event_impl(const nt_event_impl &);
429 nt_event_impl(nt_event_impl &&) noexcept;
430
431 nt_event_impl &operator=(const nt_event_impl &);
432 nt_event_impl &operator=(nt_event_impl &&) noexcept;
433
434 ~nt_event_impl();
435
436 const expression &get_expression() const;
437 const callback_t &get_callback() const;
438 event_direction get_direction() const;
439 };
440
441 template <typename T, bool B>
operator <<(std::ostream & os,const nt_event_impl<T,B> &)442 inline std::ostream &operator<<(std::ostream &os, const nt_event_impl<T, B> &)
443 {
444 static_assert(always_false_v<T>, "Unhandled type.");
445
446 return os;
447 }
448
449 template <>
450 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const nt_event_impl<double, false> &);
451 template <>
452 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const nt_event_impl<double, true> &);
453
454 template <>
455 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const nt_event_impl<long double, false> &);
456 template <>
457 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const nt_event_impl<long double, true> &);
458
459 #if defined(HEYOKA_HAVE_REAL128)
460
461 template <>
462 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const nt_event_impl<mppp::real128, false> &);
463 template <>
464 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const nt_event_impl<mppp::real128, true> &);
465
466 #endif
467
468 template <typename T, bool B>
469 class HEYOKA_DLL_PUBLIC t_event_impl
470 {
471 static_assert(is_supported_fp_v<T>, "Unhandled type.");
472
473 public:
474 using callback_t = callable<std::conditional_t<B, bool(taylor_adaptive_batch_impl<T> &, bool, int, std::uint32_t),
475 bool(taylor_adaptive_impl<T> &, bool, int)>>;
476
477 private:
478 expression eq;
479 callback_t callback;
480 T cooldown;
481 event_direction dir;
482
483 // Serialization.
484 friend class boost::serialization::access;
485 template <typename Archive>
serialize(Archive & ar,unsigned)486 void serialize(Archive &ar, unsigned)
487 {
488 ar &eq;
489 ar &callback;
490 ar &cooldown;
491 ar &dir;
492 }
493
494 void finalise_ctor(callback_t, T, event_direction);
495
496 public:
497 t_event_impl();
498
499 template <typename... KwArgs>
t_event_impl(const expression & e,KwArgs &&...kw_args)500 explicit t_event_impl(const expression &e, KwArgs &&...kw_args) : eq(copy(e))
501 {
502 igor::parser p{kw_args...};
503
504 if constexpr (p.has_unnamed_arguments()) {
505 static_assert(detail::always_false_v<KwArgs...>,
506 "The variadic arguments in the construction of a terminal event contain "
507 "unnamed arguments.");
508 throw;
509 } else {
510 // Callback (defaults to empty).
511 auto cb = [&p]() -> callback_t {
512 if constexpr (p.has(kw::callback)) {
513 return std::forward<decltype(p(kw::callback))>(p(kw::callback));
514 } else {
515 return {};
516 }
517 }();
518
519 // Cooldown (defaults to -1).
520 auto cd = [&p]() -> T {
521 if constexpr (p.has(kw::cooldown)) {
522 return std::forward<decltype(p(kw::cooldown))>(p(kw::cooldown));
523 } else {
524 return T(-1);
525 }
526 }();
527
528 // Direction (defaults to any).
529 auto d = [&p]() -> event_direction {
530 if constexpr (p.has(kw::direction)) {
531 return std::forward<decltype(p(kw::direction))>(p(kw::direction));
532 } else {
533 return event_direction::any;
534 }
535 }();
536
537 finalise_ctor(std::move(cb), cd, d);
538 }
539 }
540
541 t_event_impl(const t_event_impl &);
542 t_event_impl(t_event_impl &&) noexcept;
543
544 t_event_impl &operator=(const t_event_impl &);
545 t_event_impl &operator=(t_event_impl &&) noexcept;
546
547 ~t_event_impl();
548
549 const expression &get_expression() const;
550 const callback_t &get_callback() const;
551 event_direction get_direction() const;
552 T get_cooldown() const;
553 };
554
555 template <typename T, bool B>
operator <<(std::ostream & os,const t_event_impl<T,B> &)556 inline std::ostream &operator<<(std::ostream &os, const t_event_impl<T, B> &)
557 {
558 static_assert(always_false_v<T>, "Unhandled type.");
559
560 return os;
561 }
562
563 template <>
564 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const t_event_impl<double, false> &);
565 template <>
566 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const t_event_impl<double, true> &);
567
568 template <>
569 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const t_event_impl<long double, false> &);
570 template <>
571 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const t_event_impl<long double, true> &);
572
573 #if defined(HEYOKA_HAVE_REAL128)
574
575 template <>
576 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const t_event_impl<mppp::real128, false> &);
577 template <>
578 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const t_event_impl<mppp::real128, true> &);
579
580 #endif
581
582 } // namespace detail
583
584 template <typename T>
585 using nt_event = detail::nt_event_impl<T, false>;
586
587 template <typename T>
588 using t_event = detail::t_event_impl<T, false>;
589
590 template <typename T>
591 using nt_event_batch = detail::nt_event_impl<T, true>;
592
593 template <typename T>
594 using t_event_batch = detail::t_event_impl<T, true>;
595
596 template <typename>
597 class HEYOKA_DLL_PUBLIC continuous_output;
598
599 namespace detail
600 {
601
602 template <typename T>
603 std::ostream &c_out_stream_impl(std::ostream &, const continuous_output<T> &);
604
605 }
606
607 template <typename T>
608 class HEYOKA_DLL_PUBLIC continuous_output
609 {
610 static_assert(detail::is_supported_fp_v<T>, "Unhandled type.");
611
612 template <typename>
613 friend class HEYOKA_DLL_PUBLIC detail::taylor_adaptive_impl;
614
615 friend std::ostream &detail::c_out_stream_impl<T>(std::ostream &, const continuous_output<T> &);
616
617 llvm_state m_llvm_state;
618 std::vector<T> m_tcs;
619 std::vector<T> m_times_hi, m_times_lo;
620 std::vector<T> m_output;
621 using fptr_t = void (*)(T *, T, const T *, const T *, const T *);
622 fptr_t m_f_ptr = nullptr;
623
624 HEYOKA_DLL_LOCAL void add_c_out_function(std::uint32_t, std::uint32_t, bool);
625 void call_impl(T);
626
627 // Serialisation.
628 friend class boost::serialization::access;
629 void save(boost::archive::binary_oarchive &, unsigned) const;
630 void load(boost::archive::binary_iarchive &, unsigned);
631 BOOST_SERIALIZATION_SPLIT_MEMBER()
632
633 public:
634 continuous_output();
635 explicit continuous_output(llvm_state &&);
636 continuous_output(const continuous_output &);
637 continuous_output(continuous_output &&) noexcept;
638 ~continuous_output();
639
640 continuous_output &operator=(const continuous_output &);
641 continuous_output &operator=(continuous_output &&) noexcept;
642
643 const llvm_state &get_llvm_state() const;
644
operator ()(T tm)645 const std::vector<T> &operator()(T tm)
646 {
647 call_impl(tm);
648 return m_output;
649 }
get_output() const650 const std::vector<T> &get_output() const
651 {
652 return m_output;
653 }
get_times() const654 const std::vector<T> &get_times() const
655 {
656 return m_times_hi;
657 }
get_tcs() const658 const std::vector<T> &get_tcs() const
659 {
660 return m_tcs;
661 }
662
663 std::pair<T, T> get_bounds() const;
664 std::size_t get_n_steps() const;
665 };
666
667 template <typename T>
operator <<(std::ostream & os,const continuous_output<T> &)668 inline std::ostream &operator<<(std::ostream &os, const continuous_output<T> &)
669 {
670 static_assert(detail::always_false_v<T>, "Unhandled type.");
671
672 return os;
673 }
674
675 template <>
676 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const continuous_output<double> &);
677
678 template <>
679 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const continuous_output<long double> &);
680
681 #if defined(HEYOKA_HAVE_REAL128)
682
683 template <>
684 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const continuous_output<mppp::real128> &);
685
686 #endif
687
688 template <typename>
689 class HEYOKA_DLL_PUBLIC continuous_output_batch;
690
691 namespace detail
692 {
693
694 template <typename T>
695 std::ostream &c_out_batch_stream_impl(std::ostream &, const continuous_output_batch<T> &);
696
697 }
698
699 template <typename T>
700 class HEYOKA_DLL_PUBLIC continuous_output_batch
701 {
702 static_assert(detail::is_supported_fp_v<T>, "Unhandled type.");
703
704 template <typename>
705 friend class HEYOKA_DLL_PUBLIC detail::taylor_adaptive_batch_impl;
706
707 friend std::ostream &detail::c_out_batch_stream_impl<T>(std::ostream &, const continuous_output_batch<T> &);
708
709 std::uint32_t m_batch_size = 0;
710 llvm_state m_llvm_state;
711 std::vector<T> m_tcs;
712 std::vector<T> m_times_hi, m_times_lo;
713 std::vector<T> m_output;
714 std::vector<T> m_tmp_tm;
715 using fptr_t = void (*)(T *, const T *, const T *, const T *, const T *);
716 fptr_t m_f_ptr = nullptr;
717
718 HEYOKA_DLL_LOCAL void add_c_out_function(std::uint32_t, std::uint32_t, bool);
719 void call_impl(const T *);
720
721 // Serialisation.
722 friend class boost::serialization::access;
723 void save(boost::archive::binary_oarchive &, unsigned) const;
724 void load(boost::archive::binary_iarchive &, unsigned);
725 BOOST_SERIALIZATION_SPLIT_MEMBER()
726
727 public:
728 continuous_output_batch();
729 explicit continuous_output_batch(llvm_state &&);
730 continuous_output_batch(const continuous_output_batch &);
731 continuous_output_batch(continuous_output_batch &&) noexcept;
732 ~continuous_output_batch();
733
734 continuous_output_batch &operator=(const continuous_output_batch &);
735 continuous_output_batch &operator=(continuous_output_batch &&) noexcept;
736
737 const llvm_state &get_llvm_state() const;
738
operator ()(const T * tm)739 const std::vector<T> &operator()(const T *tm)
740 {
741 call_impl(tm);
742 return m_output;
743 }
744 const std::vector<T> &operator()(const std::vector<T> &);
745
get_output() const746 const std::vector<T> &get_output() const
747 {
748 return m_output;
749 }
750 // NOTE: when documenting this function,
751 // we need to warn about the padding.
get_times() const752 const std::vector<T> &get_times() const
753 {
754 return m_times_hi;
755 }
get_tcs() const756 const std::vector<T> &get_tcs() const
757 {
758 return m_tcs;
759 }
760 std::uint32_t get_batch_size() const;
761
762 std::pair<std::vector<T>, std::vector<T>> get_bounds() const;
763 std::size_t get_n_steps() const;
764 };
765
766 template <typename T>
operator <<(std::ostream & os,const continuous_output_batch<T> &)767 inline std::ostream &operator<<(std::ostream &os, const continuous_output_batch<T> &)
768 {
769 static_assert(detail::always_false_v<T>, "Unhandled type.");
770
771 return os;
772 }
773
774 template <>
775 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const continuous_output_batch<double> &);
776
777 template <>
778 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const continuous_output_batch<long double> &);
779
780 #if defined(HEYOKA_HAVE_REAL128)
781
782 template <>
783 HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const continuous_output_batch<mppp::real128> &);
784
785 #endif
786
787 namespace detail
788 {
789
// Cache of polynomials used during event detection.
// Each polynomial in the cache is stored as
// a vector of coefficients.
template <typename T>
using taylor_poly_cache = std::vector<std::vector<T>>;

// RAII wrapper used during event detection: it extracts
// a polynomial from a cache and hands it back to the
// cache upon destruction.
template <typename>
class taylor_pwrap;
801
802 template <typename T>
803 class HEYOKA_DLL_PUBLIC taylor_adaptive_impl
804 {
805 static_assert(is_supported_fp_v<T>, "Unhandled type.");
806
807 public:
808 using nt_event_t = nt_event<T>;
809 using t_event_t = t_event<T>;
810
811 private:
812 // Struct implementing the data/logic for event detection.
813 struct HEYOKA_DLL_PUBLIC ed_data {
814 // The working list type used during real root isolation.
815 using wlist_t = std::vector<std::tuple<T, T, taylor_pwrap<T>>>;
816 // The type used to store the list of isolating intervals.
817 using isol_t = std::vector<std::tuple<T, T>>;
818 // Polynomial translation function type.
819 using pt_t = void (*)(T *, const T *);
820 // rtscc function type.
821 using rtscc_t = void (*)(T *, T *, std::uint32_t *, const T *);
822 // fex_check function type.
823 using fex_check_t = void (*)(const T *, const T *, const std::uint32_t *, std::uint32_t *);
824
825 // The vector of terminal events.
826 std::vector<t_event_t> m_tes;
827 // The vector of non-terminal events.
828 std::vector<nt_event_t> m_ntes;
829 // The jet of derivatives for the state variables
830 // and the events.
831 std::vector<T> m_ev_jet;
832 // Vector of detected terminal events.
833 std::vector<std::tuple<std::uint32_t, T, bool, int, T>> m_d_tes;
834 // The vector of cooldowns for the terminal events.
835 // If an event is on cooldown, the corresponding optional
836 // in this vector will contain the total time elapsed
837 // since the cooldown started and the absolute value
838 // of the cooldown duration.
839 std::vector<std::optional<std::pair<T, T>>> m_te_cooldowns;
840 // Vector of detected non-terminal events.
841 std::vector<std::tuple<std::uint32_t, T, int>> m_d_ntes;
842 // The LLVM state.
843 llvm_state m_state;
844 // The JIT compiled functions used during root finding.
845 // NOTE: use default member initializers to ensure that
846 // these are zero-inited by the default constructor
847 // (which is defaulted).
848 pt_t m_pt = nullptr;
849 rtscc_t m_rtscc = nullptr;
850 fex_check_t m_fex_check = nullptr;
851 // The working list.
852 wlist_t m_wlist;
853 // The list of isolating intervals.
854 isol_t m_isol;
855 // The polynomial cache.
856 taylor_poly_cache<T> m_poly_cache;
857
858 // Constructors.
859 ed_data(std::vector<t_event_t>, std::vector<nt_event_t>, std::uint32_t, std::uint32_t);
860 ed_data(const ed_data &);
861 ~ed_data();
862
863 // Delete unused bits.
864 ed_data(ed_data &&) = delete;
865 ed_data &operator=(const ed_data &) = delete;
866 ed_data &operator=(ed_data &&) = delete;
867
868 // The event detection function.
869 void detect_events(T, std::uint32_t, std::uint32_t, T);
870
871 private:
872 // Serialisation.
873 // NOTE: the def ctor is used only during deserialisation
874 // via pointer.
875 ed_data();
876 friend class boost::serialization::access;
877 void save(boost::archive::binary_oarchive &, unsigned) const;
878 void load(boost::archive::binary_iarchive &, unsigned);
879 BOOST_SERIALIZATION_SPLIT_MEMBER()
880 };
881
// Data members of the scalar adaptive integrator.

// State vector.
std::vector<T> m_state;
// Time.
// NOTE: stored as a dfloat<T>; get_time() exposes it via a cast to T
// and set_time() rebuilds it from a single T.
dfloat<T> m_time;
// The LLVM machinery.
llvm_state m_llvm;
// Dimension of the system.
std::uint32_t m_dim;
// Taylor decomposition.
taylor_dc_t m_dc;
// Taylor order.
std::uint32_t m_order;
// Tolerance.
T m_tol;
// High accuracy.
bool m_high_accuracy;
// Compact mode.
bool m_compact_mode;
// The steppers.
// NOTE: m_step_f holds exactly one of the two function pointer types.
// step_f_e_t takes one additional const T * argument with respect to
// step_f_t (presumably this is the stepper used with events - confirm
// against the implementation).
using step_f_t = void (*)(T *, const T *, const T *, T *, T *);
using step_f_e_t = void (*)(T *, const T *, const T *, const T *, T *, T *);
std::variant<step_f_t, step_f_e_t> m_step_f;
// The vector of parameters.
std::vector<T> m_pars;
// The vector for the Taylor coefficients.
std::vector<T> m_tc;
// Size of the last timestep taken.
// NOTE: zero-initialised via the default member initializer.
T m_last_h = T(0);
// The function for computing the dense output.
using d_out_f_t = void (*)(T *, const T *, const T *);
d_out_f_t m_d_out_f;
// The vector for the dense output.
std::vector<T> m_d_out;
// Auxiliary data/functions for event detection.
// NOTE: a null pointer here means that the integrator has no events
// (see with_events() and the event getters, which throw in that case).
std::unique_ptr<ed_data> m_ed_data;
917
// Serialization.
// NOTE: save_impl()/load_impl() are templated on the archive type and are
// not exported from the library (HEYOKA_DLL_LOCAL). save()/load() are the
// entry points for Boost's binary archives (presumably they delegate to the
// *_impl() templates - the definitions are not visible in this header).
template <typename Archive>
HEYOKA_DLL_LOCAL void save_impl(Archive &, unsigned) const;
template <typename Archive>
HEYOKA_DLL_LOCAL void load_impl(Archive &, unsigned);

// Boost.Serialization needs access to the private save()/load() members.
friend class boost::serialization::access;
void save(boost::archive::binary_oarchive &, unsigned) const;
void load(boost::archive::binary_iarchive &, unsigned);
BOOST_SERIALIZATION_SPLIT_MEMBER()

// Implementation of the single-step functions. Returns the outcome of the
// step together with a value of type T (presumably the timestep that was
// taken; the arguments look like the max timestep and the write-tc flag -
// confirm against the definition).
HEYOKA_DLL_LOCAL std::tuple<taylor_outcome, T> step_impl(T, bool);
930
// Private implementation-detail constructor machinery.
// NOTE: apparently on Windows we need to re-iterate
// here that this is going to be dll-exported.
// NOTE: as invoked by finalise_ctor(), the arguments are, in order:
// the system of equations, the initial state, the initial time, the
// tolerance, the high accuracy flag, the compact mode flag, the values
// of the parameters, the terminal events and the non-terminal events.
template <typename U>
HEYOKA_DLL_PUBLIC void finalise_ctor_impl(const U &, std::vector<T>, T, T, bool, bool, std::vector<T>,
                                          std::vector<t_event_t>, std::vector<nt_event_t>);
937 template <typename U, typename... KwArgs>
finalise_ctor(const U & sys,std::vector<T> state,KwArgs &&...kw_args)938 void finalise_ctor(const U &sys, std::vector<T> state, KwArgs &&...kw_args)
939 {
940 igor::parser p{kw_args...};
941
942 if constexpr (p.has_unnamed_arguments()) {
943 static_assert(detail::always_false_v<KwArgs...>,
944 "The variadic arguments in the construction of an adaptive Taylor integrator contain "
945 "unnamed arguments.");
946 } else {
947 // Initial time (defaults to zero).
948 const auto time = [&p]() -> T {
949 if constexpr (p.has(kw::time)) {
950 return std::forward<decltype(p(kw::time))>(p(kw::time));
951 } else {
952 return T(0);
953 }
954 }();
955
956 auto [high_accuracy, tol, compact_mode, pars]
957 = taylor_adaptive_common_ops<T>(std::forward<KwArgs>(kw_args)...);
958
959 // Extract the terminal events, if any.
960 auto tes = [&p]() -> std::vector<t_event_t> {
961 if constexpr (p.has(kw::t_events)) {
962 return std::forward<decltype(p(kw::t_events))>(p(kw::t_events));
963 } else {
964 return {};
965 }
966 }();
967
968 // Extract the non-terminal events, if any.
969 auto ntes = [&p]() -> std::vector<nt_event_t> {
970 if constexpr (p.has(kw::nt_events)) {
971 return std::forward<decltype(p(kw::nt_events))>(p(kw::nt_events));
972 } else {
973 return {};
974 }
975 }();
976
977 finalise_ctor_impl(sys, std::move(state), time, tol, high_accuracy, compact_mode, std::move(pars),
978 std::move(tes), std::move(ntes));
979 }
980 }
981
982 public:
// Default constructor (defined out of line).
taylor_adaptive_impl();
984
// Constructor from a system given as a vector of expressions, an initial
// state and a set of keyword arguments.
// NOTE: kw_args is forwarded twice: first to the llvm_state member, then to
// finalise_ctor(), which parses the integrator-specific options.
// NOTE(review): this relies on the two consumers not moving from the same
// keyword argument - confirm against llvm_state's constructor.
template <typename... KwArgs>
explicit taylor_adaptive_impl(const std::vector<expression> &sys, std::vector<T> state, KwArgs &&...kw_args)
    : m_llvm{std::forward<KwArgs>(kw_args)...}
{
    finalise_ctor(sys, std::move(state), std::forward<KwArgs>(kw_args)...);
}
// Constructor from a system given as a vector of pairs of expressions,
// an initial state and a set of keyword arguments.
// NOTE: kw_args is forwarded twice: first to the llvm_state member, then to
// finalise_ctor(), which parses the integrator-specific options.
// NOTE(review): this relies on the two consumers not moving from the same
// keyword argument - confirm against llvm_state's constructor.
template <typename... KwArgs>
explicit taylor_adaptive_impl(const std::vector<std::pair<expression, expression>> &sys, std::vector<T> state,
                              KwArgs &&...kw_args)
    : m_llvm{std::forward<KwArgs>(kw_args)...}
{
    finalise_ctor(sys, std::move(state), std::forward<KwArgs>(kw_args)...);
}
998
// Copy/move constructors (defined out of line). The move constructor
// is declared noexcept.
taylor_adaptive_impl(const taylor_adaptive_impl &);
taylor_adaptive_impl(taylor_adaptive_impl &&) noexcept;

// Copy/move assignment operators (defined out of line). The move
// assignment operator is declared noexcept.
taylor_adaptive_impl &operator=(const taylor_adaptive_impl &);
taylor_adaptive_impl &operator=(taylor_adaptive_impl &&) noexcept;

// Destructor (defined out of line).
~taylor_adaptive_impl();
1006
// Read-only accessors, defined out of line.
// NOTE(review): presumably these return the correspondingly-named private
// data members (m_llvm, m_dc, m_order, ...) - the definitions are not
// visible in this header.
const llvm_state &get_llvm_state() const;

const taylor_dc_t &get_decomposition() const;

std::uint32_t get_order() const;
T get_tol() const;
bool get_high_accuracy() const;
bool get_compact_mode() const;
std::uint32_t get_dim() const;
1016
// Get the current time, converted from the internal dfloat<T>
// representation to a single value of type T.
T get_time() const
{
    return static_cast<T>(m_time);
}
// Set the current time: the internal dfloat<T> is rebuilt from t.
void set_time(T t)
{
    m_time = dfloat<T>(t);
}
1025
// Read-only access to the state vector.
const std::vector<T> &get_state() const
{
    return m_state;
}
// Pointer to the (read-only) data of the state vector.
const T *get_state_data() const
{
    return m_state.data();
}
// Pointer to the (mutable) data of the state vector: the values
// can be written to, but the vector cannot be resized through it.
T *get_state_data()
{
    return m_state.data();
}
1038
// Read-only access to the vector of parameters.
const std::vector<T> &get_pars() const
{
    return m_pars;
}
// Pointer to the (read-only) data of the vector of parameters.
const T *get_pars_data() const
{
    return m_pars.data();
}
// Pointer to the (mutable) data of the vector of parameters: the values
// can be written to, but the vector cannot be resized through it.
T *get_pars_data()
{
    return m_pars.data();
}
1051
// Read-only access to the vector of Taylor coefficients.
const std::vector<T> &get_tc() const
{
    return m_tc;
}

// Size of the last timestep taken (zero if m_last_h was never updated
// after its default initialisation).
T get_last_h() const
{
    return m_last_h;
}

// Read-only access to the vector holding the dense output.
const std::vector<T> &get_d_output() const
{
    return m_d_out;
}
// Update the dense output (defined out of line) and return a reference
// to a vector of values.
// NOTE(review): presumably the first argument is the time at which the
// dense output is evaluated, the boolean flag selects how that time is
// interpreted and the returned reference is m_d_out - confirm against
// the definition.
const std::vector<T> &update_d_output(T, bool = false);
1067
with_events() const1068 bool with_events() const
1069 {
1070 return static_cast<bool>(m_ed_data);
1071 }
// Defined out of line.
// NOTE(review): presumably this resets the cooldowns of the terminal
// events (see get_te_cooldowns()) - confirm against the definition.
void reset_cooldowns();
get_t_events() const1073 const std::vector<t_event_t> &get_t_events() const
1074 {
1075 if (!m_ed_data) {
1076 throw std::invalid_argument("No events were defined for this integrator");
1077 }
1078
1079 return m_ed_data->m_tes;
1080 }
get_te_cooldowns() const1081 const auto &get_te_cooldowns() const
1082 {
1083 if (!m_ed_data) {
1084 throw std::invalid_argument("No events were defined for this integrator");
1085 }
1086
1087 return m_ed_data->m_te_cooldowns;
1088 }
get_nt_events() const1089 const std::vector<nt_event_t> &get_nt_events() const
1090 {
1091 if (!m_ed_data) {
1092 throw std::invalid_argument("No events were defined for this integrator");
1093 }
1094
1095 return m_ed_data->m_ntes;
1096 }
1097
// Single-step functions (defined out of line). Each one returns the
// outcome of the step together with a value of type T.
// NOTE(review): presumably the T in the return value is the timestep that
// was taken, the boolean flag requests the writing of the Taylor
// coefficients, and the T argument of the last overload is a limit on the
// timestep - confirm against the definitions.
std::tuple<taylor_outcome, T> step(bool = false);
std::tuple<taylor_outcome, T> step_backward(bool = false);
std::tuple<taylor_outcome, T> step(T, bool = false);
1101
1102 private:
1103 // Parser for the common kwargs options for the propagate_*() functions.
1104 template <bool Grid, typename... KwArgs>
propagate_common_ops(KwArgs &&...kw_args)1105 static auto propagate_common_ops(KwArgs &&...kw_args)
1106 {
1107 igor::parser p{kw_args...};
1108
1109 if constexpr (p.has_unnamed_arguments()) {
1110 static_assert(detail::always_false_v<KwArgs...>, "The variadic arguments to a propagate_*() function in an "
1111 "adaptive Taylor integrator contain unnamed arguments.");
1112 throw;
1113 } else {
1114 // Max number of steps (defaults to zero).
1115 auto max_steps = [&p]() -> std::size_t {
1116 if constexpr (p.has(kw::max_steps)) {
1117 return std::forward<decltype(p(kw::max_steps))>(p(kw::max_steps));
1118 } else {
1119 return 0;
1120 }
1121 }();
1122
1123 // Max delta_t (defaults to positive infinity).
1124 auto max_delta_t = [&p]() -> T {
1125 if constexpr (p.has(kw::max_delta_t)) {
1126 return std::forward<decltype(p(kw::max_delta_t))>(p(kw::max_delta_t));
1127 } else {
1128 return std::numeric_limits<T>::infinity();
1129 }
1130 }();
1131
1132 // Callback (defaults to empty).
1133 auto cb = [&p]() -> std::function<bool(taylor_adaptive_impl &)> {
1134 if constexpr (p.has(kw::callback)) {
1135 return std::forward<decltype(p(kw::callback))>(p(kw::callback));
1136 } else {
1137 return {};
1138 }
1139 }();
1140
1141 // Write the Taylor coefficients (defaults to false).
1142 // NOTE: this won't be used in propagate_grid().
1143 auto write_tc = [&p]() -> bool {
1144 if constexpr (p.has(kw::write_tc)) {
1145 return std::forward<decltype(p(kw::write_tc))>(p(kw::write_tc));
1146 } else {
1147 return false;
1148 }
1149 }();
1150
1151 if constexpr (Grid) {
1152 return std::tuple{max_steps, max_delta_t, std::move(cb), write_tc};
1153 } else {
1154 // Continuous output (defaults to false).
1155 auto with_c_out = [&p]() -> bool {
1156 if constexpr (p.has(kw::c_output)) {
1157 return std::forward<decltype(p(kw::c_output))>(p(kw::c_output));
1158 } else {
1159 return false;
1160 }
1161 }();
1162
1163 return std::tuple{max_steps, max_delta_t, std::move(cb), write_tc, with_c_out};
1164 }
1165 }
1166 }
1167
// Implementations of the propagate_*() functions (defined out of line).
// NOTE: as invoked by propagate_until()/propagate_for(), the arguments of
// propagate_until_impl() are: the final time, the max number of steps, the
// max delta_t, the callback, the write_tc flag and the continuous output
// flag. propagate_grid_impl() takes the time grid, the max number of steps,
// the max delta_t and the callback.
std::tuple<taylor_outcome, T, T, std::size_t, std::optional<continuous_output<T>>>
propagate_until_impl(const dfloat<T> &, std::size_t, T, std::function<bool(taylor_adaptive_impl &)>, bool, bool);
std::tuple<taylor_outcome, T, T, std::size_t, std::vector<T>>
propagate_grid_impl(const std::vector<T> &, std::size_t, T, std::function<bool(taylor_adaptive_impl &)>);
1173
1174 public:
1175 // NOTE: return values:
1176 // - outcome,
1177 // - min abs(timestep),
1178 // - max abs(timestep),
1179 // - total number of nonzero steps
1180 // successfully undertaken,
1181 // - grid of state vectors (only for propagate_grid()),
1182 // - continuous output, if requested (only for propagate_for/until()).
1183 // NOTE: the min/max timesteps are well-defined
1184 // only if at least 1-2 steps were taken successfully.
1185 template <typename... KwArgs>
1186 std::tuple<taylor_outcome, T, T, std::size_t, std::optional<continuous_output<T>>>
propagate_until(T t,KwArgs &&...kw_args)1187 propagate_until(T t, KwArgs &&...kw_args)
1188 {
1189 auto [max_steps, max_delta_t, cb, write_tc, with_c_out]
1190 = propagate_common_ops<false>(std::forward<KwArgs>(kw_args)...);
1191
1192 return propagate_until_impl(dfloat<T>(t), max_steps, max_delta_t, std::move(cb), write_tc, with_c_out);
1193 }
1194 template <typename... KwArgs>
1195 std::tuple<taylor_outcome, T, T, std::size_t, std::optional<continuous_output<T>>>
propagate_for(T delta_t,KwArgs &&...kw_args)1196 propagate_for(T delta_t, KwArgs &&...kw_args)
1197 {
1198 auto [max_steps, max_delta_t, cb, write_tc, with_c_out]
1199 = propagate_common_ops<false>(std::forward<KwArgs>(kw_args)...);
1200
1201 return propagate_until_impl(m_time + delta_t, max_steps, max_delta_t, std::move(cb), write_tc, with_c_out);
1202 }
1203 // NOTE: grid is taken by copy because in the implementation loop we keep on reading from it.
1204 // Hence, we need to avoid any aliasing issue with other public integrator data.
1205 template <typename... KwArgs>
propagate_grid(std::vector<T> grid,KwArgs &&...kw_args)1206 std::tuple<taylor_outcome, T, T, std::size_t, std::vector<T>> propagate_grid(std::vector<T> grid,
1207 KwArgs &&...kw_args)
1208 {
1209 auto [max_steps, max_delta_t, cb, _] = propagate_common_ops<true>(std::forward<KwArgs>(kw_args)...);
1210
1211 return propagate_grid_impl(grid, max_steps, max_delta_t, std::move(cb));
1212 }
1213 };
1214
1215 } // namespace detail
1216
// Public name of the scalar adaptive Taylor integrator: an alias
// for the implementation class in the detail namespace.
template <typename T>
using taylor_adaptive = detail::taylor_adaptive_impl<T>;
1219
1220 namespace detail
1221 {
1222
1223 template <typename T>
1224 class HEYOKA_DLL_PUBLIC taylor_adaptive_batch_impl
1225 {
// T must be one of the supported floating-point types.
static_assert(is_supported_fp_v<T>, "Unhandled type.");

public:
// The types of the non-terminal and terminal events
// in batch mode.
using nt_event_t = nt_event_batch<T>;
using t_event_t = t_event_batch<T>;
1231
1232 private:
// Struct implementing the data/logic for event detection.
// NOTE: objects of this type are copy-constructible only: move construction
// and both assignment operators are deleted, and the default constructor
// is private (it is needed only for deserialisation).
struct HEYOKA_DLL_PUBLIC ed_data {
    // The working list type used during real root isolation.
    using wlist_t = std::vector<std::tuple<T, T, taylor_pwrap<T>>>;
    // The type used to store the list of isolating intervals.
    using isol_t = std::vector<std::tuple<T, T>>;
    // Polynomial translation function type.
    using pt_t = void (*)(T *, const T *);
    // rtscc function type.
    using rtscc_t = void (*)(T *, T *, std::uint32_t *, const T *);
    // fex_check function type.
    using fex_check_t = void (*)(const T *, const T *, const std::uint32_t *, std::uint32_t *);

    // The vector of terminal events.
    std::vector<t_event_t> m_tes;
    // The vector of non-terminal events.
    std::vector<nt_event_t> m_ntes;
    // The jet of derivatives for the state variables
    // and the events.
    std::vector<T> m_ev_jet;
    // The vector to store the norm infinity of the state
    // vector when using the stepper with events.
    std::vector<T> m_max_abs_state;
    // The vector to store the maximum absolute error
    // on the Taylor series of the event equations.
    std::vector<T> m_g_eps;
    // Vector of detected terminal events.
    // NOTE: this is a vector of vectors (presumably one
    // inner vector per batch element - confirm).
    std::vector<std::vector<std::tuple<std::uint32_t, T, bool, int, T>>> m_d_tes;
    // The vector of cooldowns for the terminal events.
    // If an event is on cooldown, the corresponding optional
    // in this vector will contain the total time elapsed
    // since the cooldown started and the absolute value
    // of the cooldown duration.
    std::vector<std::vector<std::optional<std::pair<T, T>>>> m_te_cooldowns;
    // Vector of detected non-terminal events.
    std::vector<std::vector<std::tuple<std::uint32_t, T, int>>> m_d_ntes;
    // The LLVM state.
    llvm_state m_state;
    // Flags to signal if we are integrating backwards in time.
    std::vector<std::uint32_t> m_back_int;
    // Output of the fast exclusion check.
    std::vector<std::uint32_t> m_fex_check_res;
    // The JIT compiled functions used during root finding.
    // NOTE: use default member initializers to ensure that
    // these are zero-inited by the default constructor
    // (which is defaulted).
    pt_t m_pt = nullptr;
    rtscc_t m_rtscc = nullptr;
    fex_check_t m_fex_check = nullptr;
    // The working list.
    wlist_t m_wlist;
    // The list of isolating intervals.
    isol_t m_isol;
    // The polynomial cache.
    taylor_poly_cache<T> m_poly_cache;

    // Constructors.
    ed_data(std::vector<t_event_t>, std::vector<nt_event_t>, std::uint32_t, std::uint32_t, std::uint32_t);
    ed_data(const ed_data &);
    ~ed_data();

    // Delete unused bits.
    ed_data(ed_data &&) = delete;
    ed_data &operator=(const ed_data &) = delete;
    ed_data &operator=(ed_data &&) = delete;

    // The event detection function.
    void detect_events(const T *, std::uint32_t, std::uint32_t, std::uint32_t);

private:
    // Serialisation.
    // NOTE: the def ctor is used only during deserialisation
    // via pointer.
    ed_data();
    friend class boost::serialization::access;
    void save(boost::archive::binary_oarchive &, unsigned) const;
    void load(boost::archive::binary_iarchive &, unsigned);
    BOOST_SERIALIZATION_SPLIT_MEMBER()
};
1312
// Data members of the batch adaptive integrator.

// The batch size.
std::uint32_t m_batch_size;
// State vectors.
std::vector<T> m_state;
// Times.
// NOTE: the times are split into two vectors; get_time() and
// get_time_data() expose only m_time_hi (presumably m_time_hi/m_time_lo
// are the high/low parts of double-length times - confirm).
std::vector<T> m_time_hi, m_time_lo;
// The LLVM machinery.
llvm_state m_llvm;
// Dimension of the system.
std::uint32_t m_dim;
// Taylor decomposition.
taylor_dc_t m_dc;
// Taylor order.
std::uint32_t m_order;
// Tolerance.
T m_tol;
// High accuracy.
bool m_high_accuracy;
// Compact mode.
bool m_compact_mode;
// The steppers.
// NOTE: m_step_f holds exactly one of the two function pointer types.
// step_f_e_t takes one additional const T * argument with respect to
// step_f_t (presumably this is the stepper used with events - confirm
// against the implementation).
using step_f_t = void (*)(T *, const T *, const T *, T *, T *);
using step_f_e_t = void (*)(T *, const T *, const T *, const T *, T *, T *);
std::variant<step_f_t, step_f_e_t> m_step_f;
// The vector of parameters.
std::vector<T> m_pars;
// The vector for the Taylor coefficients.
std::vector<T> m_tc;
// The sizes of the last timesteps taken.
std::vector<T> m_last_h;
// The function for computing the dense output.
using d_out_f_t = void (*)(T *, const T *, const T *);
d_out_f_t m_d_out_f;
// The vector for the dense output.
std::vector<T> m_d_out;
// Temporary vectors for use
// in the timestepping functions.
// These two are used as default values,
// they must never be modified.
// NOTE: m_pinf is what the propagate_*() functions pass as max delta_t
// when the user does not provide one.
std::vector<T> m_pinf, m_minf;
// This is used as temporary storage in step_impl().
std::vector<T> m_delta_ts;
// The vectors used to store the results of the step
// and propagate functions.
// NOTE: these are exposed via get_step_res() and get_propagate_res().
std::vector<std::tuple<taylor_outcome, T>> m_step_res;
std::vector<std::tuple<taylor_outcome, T, T, std::size_t>> m_prop_res;
// Temporary vectors used in the propagate_*() implementations.
std::vector<std::size_t> m_ts_count;
std::vector<T> m_min_abs_h, m_max_abs_h;
std::vector<T> m_cur_max_delta_ts;
std::vector<dfloat<T>> m_pfor_ts;
std::vector<int> m_t_dir;
std::vector<dfloat<T>> m_rem_time;
// Temporary vector used in the dense output implementation.
std::vector<T> m_d_out_time;
// Auxiliary data/functions for event detection.
// NOTE: a null pointer here means that the integrator has no events
// (see with_events() and the event getters, which throw in that case).
std::unique_ptr<ed_data> m_ed_data;
1370
// Serialization.
// NOTE: save_impl()/load_impl() are templated on the archive type and are
// not exported from the library (HEYOKA_DLL_LOCAL). save()/load() are the
// entry points for Boost's binary archives (presumably they delegate to the
// *_impl() templates - the definitions are not visible in this header).
template <typename Archive>
HEYOKA_DLL_LOCAL void save_impl(Archive &, unsigned) const;
template <typename Archive>
HEYOKA_DLL_LOCAL void load_impl(Archive &, unsigned);

// Boost.Serialization needs access to the private save()/load() members.
friend class boost::serialization::access;
void save(boost::archive::binary_oarchive &, unsigned) const;
void load(boost::archive::binary_iarchive &, unsigned);
BOOST_SERIALIZATION_SPLIT_MEMBER()

// Implementation of the single-step functions. It returns nothing
// (presumably the results are written into m_step_res, the vector argument
// holds the max timesteps and the flag requests the writing of the Taylor
// coefficients - confirm against the definition).
HEYOKA_DLL_LOCAL void step_impl(const std::vector<T> &, bool);
1383
// Private implementation-detail constructor machinery.
// NOTE: as invoked by finalise_ctor(), the arguments are, in order:
// the system of equations, the initial state, the batch size, the initial
// times, the tolerance, the high accuracy flag, the compact mode flag,
// the values of the parameters, the terminal events and the non-terminal
// events.
template <typename U>
HEYOKA_DLL_PUBLIC void finalise_ctor_impl(const U &, std::vector<T>, std::uint32_t, std::vector<T>, T, bool, bool,
                                          std::vector<T>, std::vector<t_event_t>, std::vector<nt_event_t>);
1388 template <typename U, typename... KwArgs>
finalise_ctor(const U & sys,std::vector<T> state,std::uint32_t batch_size,KwArgs &&...kw_args)1389 void finalise_ctor(const U &sys, std::vector<T> state, std::uint32_t batch_size, KwArgs &&...kw_args)
1390 {
1391 igor::parser p{kw_args...};
1392
1393 if constexpr (p.has_unnamed_arguments()) {
1394 static_assert(detail::always_false_v<KwArgs...>,
1395 "The variadic arguments in the construction of an adaptive batch Taylor integrator contain "
1396 "unnamed arguments.");
1397 } else {
1398 // Initial times (defaults to a vector of zeroes).
1399 auto time = [&p, batch_size]() -> std::vector<T> {
1400 if constexpr (p.has(kw::time)) {
1401 return std::forward<decltype(p(kw::time))>(p(kw::time));
1402 } else {
1403 return std::vector<T>(static_cast<typename std::vector<T>::size_type>(batch_size), T(0));
1404 }
1405 }();
1406
1407 auto [high_accuracy, tol, compact_mode, pars]
1408 = taylor_adaptive_common_ops<T>(std::forward<KwArgs>(kw_args)...);
1409
1410 // Extract the terminal events, if any.
1411 auto tes = [&p]() -> std::vector<t_event_t> {
1412 if constexpr (p.has(kw::t_events)) {
1413 return std::forward<decltype(p(kw::t_events))>(p(kw::t_events));
1414 } else {
1415 return {};
1416 }
1417 }();
1418
1419 // Extract the non-terminal events, if any.
1420 auto ntes = [&p]() -> std::vector<nt_event_t> {
1421 if constexpr (p.has(kw::nt_events)) {
1422 return std::forward<decltype(p(kw::nt_events))>(p(kw::nt_events));
1423 } else {
1424 return {};
1425 }
1426 }();
1427
1428 finalise_ctor_impl(sys, std::move(state), batch_size, std::move(time), tol, high_accuracy, compact_mode,
1429 std::move(pars), std::move(tes), std::move(ntes));
1430 }
1431 }
1432
1433 public:
// Default constructor (defined out of line).
taylor_adaptive_batch_impl();
1435
// Constructor from a system given as a vector of expressions, an initial
// state, a batch size and a set of keyword arguments.
// NOTE: kw_args is forwarded twice: first to the llvm_state member, then to
// finalise_ctor(), which parses the integrator-specific options.
// NOTE(review): this relies on the two consumers not moving from the same
// keyword argument - confirm against llvm_state's constructor.
template <typename... KwArgs>
explicit taylor_adaptive_batch_impl(const std::vector<expression> &sys, std::vector<T> state,
                                    std::uint32_t batch_size, KwArgs &&...kw_args)
    : m_llvm{std::forward<KwArgs>(kw_args)...}
{
    finalise_ctor(sys, std::move(state), batch_size, std::forward<KwArgs>(kw_args)...);
}
// Constructor from a system given as a vector of pairs of expressions,
// an initial state, a batch size and a set of keyword arguments.
// NOTE: kw_args is forwarded twice: first to the llvm_state member, then to
// finalise_ctor(), which parses the integrator-specific options.
// NOTE(review): this relies on the two consumers not moving from the same
// keyword argument - confirm against llvm_state's constructor.
template <typename... KwArgs>
explicit taylor_adaptive_batch_impl(const std::vector<std::pair<expression, expression>> &sys, std::vector<T> state,
                                    std::uint32_t batch_size, KwArgs &&...kw_args)
    : m_llvm{std::forward<KwArgs>(kw_args)...}
{
    finalise_ctor(sys, std::move(state), batch_size, std::forward<KwArgs>(kw_args)...);
}
1450
// Copy/move constructors (defined out of line). The move constructor
// is declared noexcept.
taylor_adaptive_batch_impl(const taylor_adaptive_batch_impl &);
taylor_adaptive_batch_impl(taylor_adaptive_batch_impl &&) noexcept;

// Copy/move assignment operators (defined out of line). The move
// assignment operator is declared noexcept.
taylor_adaptive_batch_impl &operator=(const taylor_adaptive_batch_impl &);
taylor_adaptive_batch_impl &operator=(taylor_adaptive_batch_impl &&) noexcept;

// Destructor (defined out of line).
~taylor_adaptive_batch_impl();
1458
// Read-only accessors, defined out of line.
// NOTE(review): presumably these return the correspondingly-named private
// data members (m_llvm, m_dc, m_batch_size, ...) - the definitions are not
// visible in this header.
const llvm_state &get_llvm_state() const;

const taylor_dc_t &get_decomposition() const;

std::uint32_t get_batch_size() const;
std::uint32_t get_order() const;
T get_tol() const;
bool get_high_accuracy() const;
bool get_compact_mode() const;
std::uint32_t get_dim() const;
1469
// Read-only access to the vector of current times.
// NOTE: only m_time_hi is exposed, m_time_lo is not.
const std::vector<T> &get_time() const
{
    return m_time_hi;
}
// Pointer to the (read-only) data of the vector of current times
// (again, m_time_hi only).
const T *get_time_data() const
{
    return m_time_hi.data();
}
// Set the current times (defined out of line).
void set_time(const std::vector<T> &);
1479
// Read-only access to the state vector.
const std::vector<T> &get_state() const
{
    return m_state;
}
// Pointer to the (read-only) data of the state vector.
const T *get_state_data() const
{
    return m_state.data();
}
// Pointer to the (mutable) data of the state vector: the values
// can be written to, but the vector cannot be resized through it.
T *get_state_data()
{
    return m_state.data();
}
1492
// Read-only access to the vector of parameters.
const std::vector<T> &get_pars() const
{
    return m_pars;
}
// Pointer to the (read-only) data of the vector of parameters.
const T *get_pars_data() const
{
    return m_pars.data();
}
// Pointer to the (mutable) data of the vector of parameters: the values
// can be written to, but the vector cannot be resized through it.
T *get_pars_data()
{
    return m_pars.data();
}
1505
// Read-only access to the vector of Taylor coefficients.
const std::vector<T> &get_tc() const
{
    return m_tc;
}

// Read-only access to the sizes of the last timesteps taken.
const std::vector<T> &get_last_h() const
{
    return m_last_h;
}

// Read-only access to the vector holding the dense output.
const std::vector<T> &get_d_output() const
{
    return m_d_out;
}
// Update the dense output (defined out of line) and return a reference
// to a vector of values.
// NOTE(review): presumably the vector argument contains the times at which
// the dense output is evaluated, the boolean flag selects how those times
// are interpreted and the returned reference is m_d_out - confirm against
// the definition.
const std::vector<T> &update_d_output(const std::vector<T> &, bool = false);
1521
with_events() const1522 bool with_events() const
1523 {
1524 return static_cast<bool>(m_ed_data);
1525 }
// Defined out of line.
// NOTE(review): presumably these reset the cooldowns of the terminal
// events (see get_te_cooldowns()), with the second overload acting only
// on the batch element with the given index - confirm against the
// definitions.
void reset_cooldowns();
void reset_cooldowns(std::uint32_t);
get_t_events() const1528 const std::vector<t_event_t> &get_t_events() const
1529 {
1530 if (!m_ed_data) {
1531 throw std::invalid_argument("No events were defined for this integrator");
1532 }
1533
1534 return m_ed_data->m_tes;
1535 }
get_te_cooldowns() const1536 const auto &get_te_cooldowns() const
1537 {
1538 if (!m_ed_data) {
1539 throw std::invalid_argument("No events were defined for this integrator");
1540 }
1541
1542 return m_ed_data->m_te_cooldowns;
1543 }
get_nt_events() const1544 const std::vector<nt_event_t> &get_nt_events() const
1545 {
1546 if (!m_ed_data) {
1547 throw std::invalid_argument("No events were defined for this integrator");
1548 }
1549
1550 return m_ed_data->m_ntes;
1551 }
1552
// Single-step functions (defined out of line). They return nothing:
// the results are retrieved via get_step_res().
// NOTE(review): presumably the boolean flag requests the writing of the
// Taylor coefficients and the vector argument of the last overload
// contains limits on the timesteps - confirm against the definitions.
void step(bool = false);
void step_backward(bool = false);
void step(const std::vector<T> &, bool = false);
// Read-only access to the vector used to store the results of the
// step functions, as (outcome, T) tuples.
const std::vector<std::tuple<taylor_outcome, T>> &get_step_res() const
{
    return m_step_res;
}
1560
1561 private:
1562 // Parser for the common kwargs options for the propagate_*() functions.
1563 template <bool Grid, typename... KwArgs>
propagate_common_ops(KwArgs &&...kw_args) const1564 auto propagate_common_ops(KwArgs &&...kw_args) const
1565 {
1566 igor::parser p{kw_args...};
1567
1568 if constexpr (p.has_unnamed_arguments()) {
1569 static_assert(detail::always_false_v<KwArgs...>,
1570 "The variadic arguments to a propagate_*() function in an "
1571 "adaptive Taylor integrator in batch mode contain unnamed arguments.");
1572 throw;
1573 } else {
1574 // Max number of steps (defaults to zero).
1575 auto max_steps = [&p]() -> std::size_t {
1576 if constexpr (p.has(kw::max_steps)) {
1577 return std::forward<decltype(p(kw::max_steps))>(p(kw::max_steps));
1578 } else {
1579 return 0;
1580 }
1581 }();
1582
1583 // Max delta_t (defaults to empty vector).
1584 // NOTE: we want an explicit copy here because
1585 // in the implementations of the propagate_*() functions
1586 // we keep on checking on max_delta_t before invoking
1587 // the single step function. Hence, we want to avoid
1588 // any risk of aliasing.
1589 auto max_delta_t = [&p]() -> std::vector<T> {
1590 if constexpr (p.has(kw::max_delta_t)) {
1591 return std::forward<decltype(p(kw::max_delta_t))>(p(kw::max_delta_t));
1592 } else {
1593 return {};
1594 }
1595 }();
1596
1597 // Callback (defaults to empty).
1598 auto cb = [&p]() -> std::function<bool(taylor_adaptive_batch_impl &)> {
1599 if constexpr (p.has(kw::callback)) {
1600 return std::forward<decltype(p(kw::callback))>(p(kw::callback));
1601 } else {
1602 return {};
1603 }
1604 }();
1605
1606 // Write the Taylor coefficients (defaults to false).
1607 // NOTE: this won't be used in propagate_grid().
1608 auto write_tc = [&p]() -> bool {
1609 if constexpr (p.has(kw::write_tc)) {
1610 return std::forward<decltype(p(kw::write_tc))>(p(kw::write_tc));
1611 } else {
1612 return false;
1613 }
1614 }();
1615
1616 if constexpr (Grid) {
1617 return std::tuple{max_steps, std::move(max_delta_t), std::move(cb), write_tc};
1618 } else {
1619 // Continuous output (defaults to false).
1620 auto with_c_out = [&p]() -> bool {
1621 if constexpr (p.has(kw::c_output)) {
1622 return std::forward<decltype(p(kw::c_output))>(p(kw::c_output));
1623 } else {
1624 return false;
1625 }
1626 }();
1627
1628 return std::tuple{max_steps, std::move(max_delta_t), std::move(cb), write_tc, with_c_out};
1629 }
1630 }
1631 }
1632
// Implementations of the propagate_*() functions (defined out of line).
// NOTE: there are two overloads of propagate_until_impl(): the one taking
// the final times as a vector of dfloat<T> is not exported from the library
// (HEYOKA_DLL_LOCAL), the one taking a vector of T is what
// propagate_until() invokes.
// NOTE: as invoked by the public propagate_*() functions, the arguments
// are: the final times / time intervals / time grid, the max number of
// steps, the max delta_ts, the callback and, for propagate_until/for(),
// the write_tc flag and the continuous output flag.
HEYOKA_DLL_LOCAL std::optional<continuous_output_batch<T>>
propagate_until_impl(const std::vector<dfloat<T>> &, std::size_t, const std::vector<T> &,
                     std::function<bool(taylor_adaptive_batch_impl &)>, bool, bool);
std::optional<continuous_output_batch<T>> propagate_until_impl(const std::vector<T> &, std::size_t,
                                                               const std::vector<T> &,
                                                               std::function<bool(taylor_adaptive_batch_impl &)>,
                                                               bool, bool);
std::optional<continuous_output_batch<T>> propagate_for_impl(const std::vector<T> &, std::size_t,
                                                             const std::vector<T> &,
                                                             std::function<bool(taylor_adaptive_batch_impl &)>,
                                                             bool, bool);
std::vector<T> propagate_grid_impl(const std::vector<T> &, std::size_t, const std::vector<T> &,
                                   std::function<bool(taylor_adaptive_batch_impl &)>);
1647
1648 public:
1649 // NOTE: in propagate_for/until(), we can take 'ts' as const reference because it is always
1650 // only and immediately used to set up the internal m_pfor_ts member (which is not visible
1651 // from outside). Hence, even if 'ts' aliases some public integrator data, it does not matter.
1652 template <typename... KwArgs>
propagate_until(const std::vector<T> & ts,KwArgs &&...kw_args)1653 std::optional<continuous_output_batch<T>> propagate_until(const std::vector<T> &ts, KwArgs &&...kw_args)
1654 {
1655 auto [max_steps, max_delta_ts, cb, write_tc, with_c_out]
1656 = propagate_common_ops<false>(std::forward<KwArgs>(kw_args)...);
1657
1658 return propagate_until_impl(ts, max_steps, max_delta_ts.empty() ? m_pinf : max_delta_ts, std::move(cb),
1659 write_tc, with_c_out); // LCOV_EXCL_LINE
1660 }
1661 template <typename... KwArgs>
propagate_for(const std::vector<T> & ts,KwArgs &&...kw_args)1662 std::optional<continuous_output_batch<T>> propagate_for(const std::vector<T> &ts, KwArgs &&...kw_args)
1663 {
1664 auto [max_steps, max_delta_ts, cb, write_tc, with_c_out]
1665 = propagate_common_ops<false>(std::forward<KwArgs>(kw_args)...);
1666
1667 return propagate_for_impl(ts, max_steps, max_delta_ts.empty() ? m_pinf : max_delta_ts, std::move(cb), write_tc,
1668 with_c_out); // LCOV_EXCL_LINE
1669 }
1670 // NOTE: grid is taken by copy because in the implementation loop we keep on reading from it.
1671 // Hence, we need to avoid any aliasing issue with other public integrator data.
1672 template <typename... KwArgs>
propagate_grid(std::vector<T> grid,KwArgs &&...kw_args)1673 std::vector<T> propagate_grid(std::vector<T> grid, KwArgs &&...kw_args)
1674 {
1675 auto [max_steps, max_delta_ts, cb, _] = propagate_common_ops<true>(std::forward<KwArgs>(kw_args)...);
1676
1677 return propagate_grid_impl(grid, max_steps, max_delta_ts.empty() ? m_pinf : max_delta_ts, std::move(cb));
1678 }
// Read-only access to the vector used to store the results of the
// propagate_*() functions, as (outcome, T, T, std::size_t) tuples.
// NOTE(review): presumably the tuple elements mirror the first four return
// values of the scalar propagate_*() functions (outcome, min/max
// abs(timestep), number of steps) - confirm against the implementation.
const std::vector<std::tuple<taylor_outcome, T, T, std::size_t>> &get_propagate_res() const
{
    return m_prop_res;
}
1683 };
1684
1685 } // namespace detail
1686
// Public name of the batch adaptive Taylor integrator: an alias
// for the implementation class in the detail namespace.
template <typename T>
using taylor_adaptive_batch = detail::taylor_adaptive_batch_impl<T>;
1689
1690 namespace detail
1691 {
1692
// Stream operator for the scalar adaptive integrator.
// NOTE: the primary template triggers a compile-time error when
// instantiated: only the explicit specialisations declared below
// (and defined out of line) can be used.
template <typename T>
inline std::ostream &operator<<(std::ostream &os, const taylor_adaptive_impl<T> &)
{
    static_assert(always_false_v<T>, "Unhandled type.");

    return os;
}

// Specialisation for double.
template <>
HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const taylor_adaptive_impl<double> &);

// Specialisation for long double.
template <>
HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const taylor_adaptive_impl<long double> &);

#if defined(HEYOKA_HAVE_REAL128)

// Specialisation for mppp::real128, available only
// if HEYOKA_HAVE_REAL128 is defined.
template <>
HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const taylor_adaptive_impl<mppp::real128> &);

#endif
1713
// Stream operator for the batch adaptive integrator.
// NOTE: the primary template triggers a compile-time error when
// instantiated: only the explicit specialisations declared below
// (and defined out of line) can be used.
template <typename T>
inline std::ostream &operator<<(std::ostream &os, const taylor_adaptive_batch_impl<T> &)
{
    static_assert(always_false_v<T>, "Unhandled type.");

    return os;
}

// Specialisation for double.
template <>
HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const taylor_adaptive_batch_impl<double> &);

// Specialisation for long double.
template <>
HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const taylor_adaptive_batch_impl<long double> &);

#if defined(HEYOKA_HAVE_REAL128)

// Specialisation for mppp::real128, available only
// if HEYOKA_HAVE_REAL128 is defined.
template <>
HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const taylor_adaptive_batch_impl<mppp::real128> &);

#endif
1734
1735 } // namespace detail
1736
1737 } // namespace heyoka
1738
1739 // NOTE: copy the implementation of the BOOST_CLASS_VERSION macro, as it does
1740 // not support class templates.
1741 namespace boost::serialization
1742 {
1743
// Serialisation version of the scalar adaptive integrator: 2.
// NOTE: the version number appears twice below (in the 'type' typedef and
// in the assertion, which checks that it is less than 256): the two
// occurrences must be kept in sync when bumping the version.
template <typename T>
struct version<heyoka::taylor_adaptive<T>> {
    typedef mpl::int_<2> type;
    typedef mpl::integral_c_tag tag;
    BOOST_STATIC_CONSTANT(int, value = version::type::value);
    BOOST_MPL_ASSERT((boost::mpl::less<boost::mpl::int_<2>, boost::mpl::int_<256>>));
};
1751
// Serialisation version of the batch adaptive integrator: 2.
// NOTE: the version number appears twice below (in the 'type' typedef and
// in the assertion, which checks that it is less than 256): the two
// occurrences must be kept in sync when bumping the version.
template <typename T>
struct version<heyoka::taylor_adaptive_batch<T>> {
    typedef mpl::int_<2> type;
    typedef mpl::integral_c_tag tag;
    BOOST_STATIC_CONSTANT(int, value = version::type::value);
    BOOST_MPL_ASSERT((boost::mpl::less<boost::mpl::int_<2>, boost::mpl::int_<256>>));
};
1759
1760 } // namespace boost::serialization
1761
1762 #endif
1763