1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #ifndef HEYOKA_FUNC_HPP
10 #define HEYOKA_FUNC_HPP
11 
12 #include <heyoka/config.hpp>
13 
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <ostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
27 
28 #if defined(HEYOKA_HAVE_REAL128)
29 
30 #include <mp++/real128.hpp>
31 
32 #endif
33 
34 #include <heyoka/detail/fwd_decl.hpp>
35 #include <heyoka/detail/llvm_fwd.hpp>
36 #include <heyoka/detail/type_traits.hpp>
37 #include <heyoka/detail/visibility.hpp>
38 #include <heyoka/exceptions.hpp>
39 #include <heyoka/s11n.hpp>
40 
41 namespace heyoka
42 {
43 
44 class HEYOKA_DLL_PUBLIC func_base
45 {
46     std::string m_name;
47     std::vector<expression> m_args;
48 
49     // Serialization.
50     friend class boost::serialization::access;
51     template <typename Archive>
serialize(Archive & ar,unsigned)52     void serialize(Archive &ar, unsigned)
53     {
54         ar &m_name;
55         ar &m_args;
56     }
57 
58 public:
59     explicit func_base(std::string, std::vector<expression>);
60 
61     func_base(const func_base &);
62     func_base(func_base &&) noexcept;
63 
64     func_base &operator=(const func_base &);
65     func_base &operator=(func_base &&) noexcept;
66 
67     ~func_base();
68 
69     const std::string &get_name() const;
70     const std::vector<expression> &args() const;
71     std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> get_mutable_args_it();
72 };
73 
74 namespace detail
75 {
76 
// Abstract base of the type-erased storage held by func (through its m_ptr member).
// func_inner<T> implements it by wrapping an object of type T and forwarding
// each virtual call to that object.
// NOTE(review): this class is exported (HEYOKA_DLL_PUBLIC), so the declaration
// order of its virtual functions fixes the vtable layout - do not reorder them.
struct HEYOKA_DLL_PUBLIC func_inner_base {
    virtual ~func_inner_base();
    // Polymorphic copy of the wrapped object.
    virtual std::unique_ptr<func_inner_base> clone() const = 0;

    // Runtime type of the wrapped object, and raw type-erased pointers to it.
    virtual std::type_index get_type_index() const = 0;
    virtual const void *get_ptr() const = 0;
    virtual void *get_ptr() = 0;

    // Name of the wrapped function.
    virtual const std::string &get_name() const = 0;

    // Textual representation of the function.
    virtual void to_stream(std::ostream &) const = 0;

    // Extra equality/hash contributions of the wrapped type (func_inner
    // defaults these to true and 0 when the wrapped type does not provide them).
    virtual bool extra_equal_to(const func &) const = 0;

    virtual std::size_t extra_hash() const = 0;

    // Access to the arguments: read-only, and as a pair of mutable iterators.
    virtual const std::vector<expression> &args() const = 0;
    virtual std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> get_mutable_args_it() = 0;

    // LLVM code generation, one entry point per floating-point type.
    virtual llvm::Value *codegen_dbl(llvm_state &, const std::vector<llvm::Value *> &) const = 0;
    virtual llvm::Value *codegen_ldbl(llvm_state &, const std::vector<llvm::Value *> &) const = 0;
#if defined(HEYOKA_HAVE_REAL128)
    virtual llvm::Value *codegen_f128(llvm_state &, const std::vector<llvm::Value *> &) const = 0;
#endif

    // Differentiation with respect to a variable (std::string overload, presumably
    // the variable's name) or a parameter, plus the gradient. The has_*() queries
    // report whether the wrapped type implements the corresponding operation.
    // NOTE(review): the unordered_map argument is only forwarded here; it looks like
    // a cache keyed on node addresses - confirm with the implementation.
    virtual bool has_diff_var() const = 0;
    virtual expression diff(std::unordered_map<const void *, expression> &, const std::string &) const = 0;
    virtual bool has_diff_par() const = 0;
    virtual expression diff(std::unordered_map<const void *, expression> &, const param &) const = 0;
    virtual bool has_gradient() const = 0;
    virtual std::vector<expression> gradient() const = 0;

    // Numerical evaluation. Presumably the map holds variable values keyed by
    // name and the vector holds parameter values - TODO confirm.
    virtual double eval_dbl(const std::unordered_map<std::string, double> &, const std::vector<double> &) const = 0;
    virtual long double eval_ldbl(const std::unordered_map<std::string, long double> &,
                                  const std::vector<long double> &) const = 0;
#if defined(HEYOKA_HAVE_REAL128)
    virtual mppp::real128 eval_f128(const std::unordered_map<std::string, mppp::real128> &,
                                    const std::vector<mppp::real128> &) const = 0;
#endif

    // Batch evaluation (writing into the first argument) and numerical
    // evaluation of the function / of its derivative wrt the argument at the given index.
    virtual void eval_batch_dbl(std::vector<double> &, const std::unordered_map<std::string, std::vector<double>> &,
                                const std::vector<double> &) const = 0;
    virtual double eval_num_dbl(const std::vector<double> &) const = 0;
    virtual double deval_num_dbl(const std::vector<double> &, std::vector<double>::size_type) const = 0;

    // Taylor decomposition. It is rvalue-qualified because func_inner moves the
    // wrapped object into the wrapped type's own taylor_decompose().
    virtual taylor_dc_t::size_type taylor_decompose(taylor_dc_t &) && = 0;
    virtual bool has_taylor_decompose() const = 0;
    // Taylor derivative code generation, one entry point per floating-point type.
    virtual llvm::Value *taylor_diff_dbl(llvm_state &, const std::vector<std::uint32_t> &,
                                         const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *,
                                         std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const = 0;
    virtual llvm::Value *taylor_diff_ldbl(llvm_state &, const std::vector<std::uint32_t> &,
                                          const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *,
                                          std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const = 0;
#if defined(HEYOKA_HAVE_REAL128)
    virtual llvm::Value *taylor_diff_f128(llvm_state &, const std::vector<std::uint32_t> &,
                                          const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *,
                                          std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const = 0;
#endif
    // Creation of the LLVM function computing Taylor derivatives in compact mode.
    virtual llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const = 0;
    virtual llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const = 0;
#if defined(HEYOKA_HAVE_REAL128)
    virtual llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const = 0;
#endif

private:
    // Serialization.
    // Nothing to archive here: this hook exists so that derived classes can
    // serialize their base via boost::serialization::base_object (as func_inner does).
    friend class boost::serialization::access;
    template <typename Archive>
    void serialize(Archive &, unsigned)
    {
    }
};
149 
150 template <typename T>
151 using func_to_stream_t
152     = decltype(std::declval<std::add_lvalue_reference_t<const T>>().to_stream(std::declval<std::ostream &>()));
153 
154 template <typename T>
155 inline constexpr bool func_has_to_stream_v = std::is_same_v<detected_t<func_to_stream_t, T>, void>;
156 
157 template <typename T>
158 using func_extra_equal_to_t
159     = decltype(std::declval<std::add_lvalue_reference_t<const T>>().extra_equal_to(std::declval<const func &>()));
160 
161 template <typename T>
162 inline constexpr bool func_has_extra_equal_to_v = std::is_same_v<detected_t<func_extra_equal_to_t, T>, bool>;
163 
164 template <typename T>
165 using func_extra_hash_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().extra_hash());
166 
167 template <typename T>
168 inline constexpr bool func_has_extra_hash_v = std::is_same_v<detected_t<func_extra_hash_t, T>, std::size_t>;
169 
170 template <typename T>
171 using func_codegen_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().codegen_dbl(
172     std::declval<llvm_state &>(), std::declval<const std::vector<llvm::Value *> &>()));
173 
174 template <typename T>
175 inline constexpr bool func_has_codegen_dbl_v = std::is_same_v<detected_t<func_codegen_dbl_t, T>, llvm::Value *>;
176 
177 template <typename T>
178 using func_codegen_ldbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().codegen_ldbl(
179     std::declval<llvm_state &>(), std::declval<const std::vector<llvm::Value *> &>()));
180 
181 template <typename T>
182 inline constexpr bool func_has_codegen_ldbl_v = std::is_same_v<detected_t<func_codegen_ldbl_t, T>, llvm::Value *>;
183 
184 #if defined(HEYOKA_HAVE_REAL128)
185 
186 template <typename T>
187 using func_codegen_f128_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().codegen_f128(
188     std::declval<llvm_state &>(), std::declval<const std::vector<llvm::Value *> &>()));
189 
190 template <typename T>
191 inline constexpr bool func_has_codegen_f128_v = std::is_same_v<detected_t<func_codegen_f128_t, T>, llvm::Value *>;
192 
193 #endif
194 
195 template <typename T>
196 using func_diff_var_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().diff(
197     std::declval<std::unordered_map<const void *, expression> &>(), std::declval<const std::string &>()));
198 
199 template <typename T>
200 inline constexpr bool func_has_diff_var_v = std::is_same_v<detected_t<func_diff_var_t, T>, expression>;
201 
202 template <typename T>
203 using func_diff_par_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().diff(
204     std::declval<std::unordered_map<const void *, expression> &>(), std::declval<const param &>()));
205 
206 template <typename T>
207 inline constexpr bool func_has_diff_par_v = std::is_same_v<detected_t<func_diff_par_t, T>, expression>;
208 
209 template <typename T>
210 using func_gradient_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().gradient());
211 
212 template <typename T>
213 inline constexpr bool func_has_gradient_v = std::is_same_v<detected_t<func_gradient_t, T>, std::vector<expression>>;
214 
215 template <typename T>
216 using func_eval_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_dbl(
217     std::declval<const std::unordered_map<std::string, double> &>(), std::declval<const std::vector<double> &>()));
218 
219 template <typename T>
220 inline constexpr bool func_has_eval_dbl_v = std::is_same_v<detected_t<func_eval_dbl_t, T>, double>;
221 
222 template <typename T>
223 using func_eval_ldbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_ldbl(
224     std::declval<const std::unordered_map<std::string, long double> &>(),
225     std::declval<const std::vector<long double> &>()));
226 
227 template <typename T>
228 inline constexpr bool func_has_eval_ldbl_v = std::is_same_v<detected_t<func_eval_ldbl_t, T>, long double>;
229 
230 #if defined(HEYOKA_HAVE_REAL128)
231 template <typename T>
232 using func_eval_f128_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_f128(
233     std::declval<const std::unordered_map<std::string, mppp::real128> &>(),
234     std::declval<const std::vector<mppp::real128> &>()));
235 
236 template <typename T>
237 inline constexpr bool func_has_eval_f128_v = std::is_same_v<detected_t<func_eval_f128_t, T>, mppp::real128>;
238 #endif
239 
240 template <typename T>
241 using func_eval_batch_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_batch_dbl(
242     std::declval<std::vector<double> &>(), std::declval<const std::unordered_map<std::string, std::vector<double>> &>(),
243     std::declval<const std::vector<double> &>()));
244 
245 template <typename T>
246 inline constexpr bool func_has_eval_batch_dbl_v = std::is_same_v<detected_t<func_eval_batch_dbl_t, T>, void>;
247 
248 template <typename T>
249 using func_eval_num_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_num_dbl(
250     std::declval<const std::vector<double> &>()));
251 
252 template <typename T>
253 inline constexpr bool func_has_eval_num_dbl_v = std::is_same_v<detected_t<func_eval_num_dbl_t, T>, double>;
254 
255 template <typename T>
256 using func_deval_num_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().deval_num_dbl(
257     std::declval<const std::vector<double> &>(), std::declval<std::vector<double>::size_type>()));
258 
259 template <typename T>
260 inline constexpr bool func_has_deval_num_dbl_v = std::is_same_v<detected_t<func_deval_num_dbl_t, T>, double>;
261 
262 template <typename T>
263 using func_taylor_decompose_t
264     = decltype(std::declval<std::add_rvalue_reference_t<T>>().taylor_decompose(std::declval<taylor_dc_t &>()));
265 
266 template <typename T>
267 inline constexpr bool func_has_taylor_decompose_v
268     = std::is_same_v<detected_t<func_taylor_decompose_t, T>, taylor_dc_t::size_type>;
269 
270 template <typename T>
271 using func_taylor_diff_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_diff_dbl(
272     std::declval<llvm_state &>(), std::declval<const std::vector<std::uint32_t> &>(),
273     std::declval<const std::vector<llvm::Value *> &>(), std::declval<llvm::Value *>(), std::declval<llvm::Value *>(),
274     std::declval<std::uint32_t>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
275     std::declval<std::uint32_t>(), std::declval<bool>()));
276 
277 template <typename T>
278 inline constexpr bool func_has_taylor_diff_dbl_v = std::is_same_v<detected_t<func_taylor_diff_dbl_t, T>, llvm::Value *>;
279 
280 template <typename T>
281 using func_taylor_diff_ldbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_diff_ldbl(
282     std::declval<llvm_state &>(), std::declval<const std::vector<std::uint32_t> &>(),
283     std::declval<const std::vector<llvm::Value *> &>(), std::declval<llvm::Value *>(), std::declval<llvm::Value *>(),
284     std::declval<std::uint32_t>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
285     std::declval<std::uint32_t>(), std::declval<bool>()));
286 
287 template <typename T>
288 inline constexpr bool func_has_taylor_diff_ldbl_v
289     = std::is_same_v<detected_t<func_taylor_diff_ldbl_t, T>, llvm::Value *>;
290 
291 #if defined(HEYOKA_HAVE_REAL128)
292 
293 template <typename T>
294 using func_taylor_diff_f128_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_diff_f128(
295     std::declval<llvm_state &>(), std::declval<const std::vector<std::uint32_t> &>(),
296     std::declval<const std::vector<llvm::Value *> &>(), std::declval<llvm::Value *>(), std::declval<llvm::Value *>(),
297     std::declval<std::uint32_t>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
298     std::declval<std::uint32_t>(), std::declval<bool>()));
299 
300 template <typename T>
301 inline constexpr bool func_has_taylor_diff_f128_v
302     = std::is_same_v<detected_t<func_taylor_diff_f128_t, T>, llvm::Value *>;
303 
304 #endif
305 
306 template <typename T>
307 using func_taylor_c_diff_func_dbl_t
308     = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_c_diff_func_dbl(
309         std::declval<llvm_state &>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
310         std::declval<bool>()));
311 
312 template <typename T>
313 inline constexpr bool func_has_taylor_c_diff_func_dbl_v
314     = std::is_same_v<detected_t<func_taylor_c_diff_func_dbl_t, T>, llvm::Function *>;
315 
316 template <typename T>
317 using func_taylor_c_diff_func_ldbl_t
318     = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_c_diff_func_ldbl(
319         std::declval<llvm_state &>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
320         std::declval<bool>()));
321 
322 template <typename T>
323 inline constexpr bool func_has_taylor_c_diff_func_ldbl_v
324     = std::is_same_v<detected_t<func_taylor_c_diff_func_ldbl_t, T>, llvm::Function *>;
325 
326 #if defined(HEYOKA_HAVE_REAL128)
327 
328 template <typename T>
329 using func_taylor_c_diff_func_f128_t
330     = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_c_diff_func_f128(
331         std::declval<llvm_state &>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
332         std::declval<bool>()));
333 
334 template <typename T>
335 inline constexpr bool func_has_taylor_c_diff_func_f128_v
336     = std::is_same_v<detected_t<func_taylor_c_diff_func_f128_t, T>, llvm::Function *>;
337 
338 #endif
339 
// Fallback stream output, used by func_inner<T>::to_stream() when the
// wrapped type T does not provide its own to_stream() member.
HEYOKA_DLL_PUBLIC void func_default_to_stream_impl(std::ostream &, const func_base &);
341 
342 template <typename T>
343 struct HEYOKA_DLL_PUBLIC_INLINE_CLASS func_inner final : func_inner_base {
344     T m_value;
345 
346     // We just need the def ctor, delete everything else.
347     func_inner() = default;
348     func_inner(const func_inner &) = delete;
349     func_inner(func_inner &&) = delete;
350     func_inner &operator=(const func_inner &) = delete;
351     func_inner &operator=(func_inner &&) = delete;
352 
353     // Constructors from T (copy and move variants).
func_innerheyoka::detail::func_inner354     explicit func_inner(const T &x) : m_value(x) {}
func_innerheyoka::detail::func_inner355     explicit func_inner(T &&x) : m_value(std::move(x)) {}
356 
357     // The clone function.
cloneheyoka::detail::func_inner358     std::unique_ptr<func_inner_base> clone() const final
359     {
360         return std::make_unique<func_inner>(m_value);
361     }
362 
363     // Get the type at runtime.
get_type_indexheyoka::detail::func_inner364     std::type_index get_type_index() const final
365     {
366         return typeid(T);
367     }
368     // Raw getters for the internal instance.
get_ptrheyoka::detail::func_inner369     const void *get_ptr() const final
370     {
371         return &m_value;
372     }
get_ptrheyoka::detail::func_inner373     void *get_ptr() final
374     {
375         return &m_value;
376     }
377 
get_nameheyoka::detail::func_inner378     const std::string &get_name() const final
379     {
380         // NOTE: make sure we are invoking the member functions
381         // from func_base (these functions could have been overriden
382         // in the derived class).
383         return static_cast<const func_base *>(&m_value)->get_name();
384     }
385 
to_streamheyoka::detail::func_inner386     void to_stream(std::ostream &os) const final
387     {
388         if constexpr (func_has_to_stream_v<T>) {
389             m_value.to_stream(os);
390         } else {
391             func_default_to_stream_impl(os, static_cast<const func_base &>(m_value));
392         }
393     }
394 
extra_equal_toheyoka::detail::func_inner395     bool extra_equal_to(const func &f) const final
396     {
397         if constexpr (func_has_extra_equal_to_v<T>) {
398             return m_value.extra_equal_to(f);
399         } else {
400             return true;
401         }
402     }
403 
extra_hashheyoka::detail::func_inner404     std::size_t extra_hash() const final
405     {
406         if constexpr (func_has_extra_hash_v<T>) {
407             return m_value.extra_hash();
408         } else {
409             return 0;
410         }
411     }
412 
argsheyoka::detail::func_inner413     const std::vector<expression> &args() const final
414     {
415         return static_cast<const func_base *>(&m_value)->args();
416     }
get_mutable_args_itheyoka::detail::func_inner417     std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> get_mutable_args_it() final
418     {
419         return static_cast<func_base *>(&m_value)->get_mutable_args_it();
420     }
421 
422     // codegen.
codegen_dblheyoka::detail::func_inner423     llvm::Value *codegen_dbl(llvm_state &s, const std::vector<llvm::Value *> &v) const final
424     {
425         if constexpr (func_has_codegen_dbl_v<T>) {
426             return m_value.codegen_dbl(s, v);
427         } else {
428             throw not_implemented_error("double codegen is not implemented for the function '" + get_name() + "'");
429         }
430     }
codegen_ldblheyoka::detail::func_inner431     llvm::Value *codegen_ldbl(llvm_state &s, const std::vector<llvm::Value *> &v) const final
432     {
433         if constexpr (func_has_codegen_ldbl_v<T>) {
434             return m_value.codegen_ldbl(s, v);
435         } else {
436             throw not_implemented_error("long double codegen is not implemented for the function '" + get_name() + "'");
437         }
438     }
439 #if defined(HEYOKA_HAVE_REAL128)
codegen_f128heyoka::detail::func_inner440     llvm::Value *codegen_f128(llvm_state &s, const std::vector<llvm::Value *> &v) const final
441     {
442         if constexpr (func_has_codegen_f128_v<T>) {
443             return m_value.codegen_f128(s, v);
444         } else {
445             throw not_implemented_error("float128 codegen is not implemented for the function '" + get_name() + "'");
446         }
447     }
448 #endif
449 
450     // diff.
has_diff_varheyoka::detail::func_inner451     bool has_diff_var() const final
452     {
453         return func_has_diff_var_v<T>;
454     }
455     expression diff(std::unordered_map<const void *, expression> &, const std::string &) const final;
has_diff_parheyoka::detail::func_inner456     bool has_diff_par() const final
457     {
458         return func_has_diff_par_v<T>;
459     }
460     expression diff(std::unordered_map<const void *, expression> &, const param &) const final;
461 
462     // gradient.
has_gradientheyoka::detail::func_inner463     bool has_gradient() const final
464     {
465         return func_has_gradient_v<T>;
466     }
gradientheyoka::detail::func_inner467     std::vector<expression> gradient() const final
468     {
469         if constexpr (func_has_gradient_v<T>) {
470             return m_value.gradient();
471         }
472 
473         // LCOV_EXCL_START
474         assert(false);
475         throw;
476         // LCOV_EXCL_STOP
477     }
478 
479     // eval.
eval_dblheyoka::detail::func_inner480     double eval_dbl(const std::unordered_map<std::string, double> &m, const std::vector<double> &pars) const final
481     {
482         if constexpr (func_has_eval_dbl_v<T>) {
483             return m_value.eval_dbl(m, pars);
484         } else {
485             throw not_implemented_error("double eval is not implemented for the function '" + get_name() + "'");
486         }
487     }
eval_ldblheyoka::detail::func_inner488     long double eval_ldbl(const std::unordered_map<std::string, long double> &m,
489                           const std::vector<long double> &pars) const final
490     {
491         if constexpr (func_has_eval_ldbl_v<T>) {
492             return m_value.eval_ldbl(m, pars);
493         } else {
494             throw not_implemented_error("long double eval is not implemented for the function '" + get_name() + "'");
495         }
496     }
497 #if defined(HEYOKA_HAVE_REAL128)
eval_f128heyoka::detail::func_inner498     mppp::real128 eval_f128(const std::unordered_map<std::string, mppp::real128> &m,
499                             const std::vector<mppp::real128> &pars) const final
500     {
501         if constexpr (func_has_eval_f128_v<T>) {
502             return m_value.eval_f128(m, pars);
503         } else {
504             throw not_implemented_error("mppp::real128 eval is not implemented for the function '" + get_name() + "'");
505         }
506     }
507 #endif
eval_batch_dblheyoka::detail::func_inner508     void eval_batch_dbl(std::vector<double> &out, const std::unordered_map<std::string, std::vector<double>> &m,
509                         const std::vector<double> &pars) const final
510     {
511         if constexpr (func_has_eval_batch_dbl_v<T>) {
512             m_value.eval_batch_dbl(out, m, pars);
513         } else {
514             throw not_implemented_error("double batch eval is not implemented for the function '" + get_name() + "'");
515         }
516     }
eval_num_dblheyoka::detail::func_inner517     double eval_num_dbl(const std::vector<double> &v) const final
518     {
519         if constexpr (func_has_eval_num_dbl_v<T>) {
520             return m_value.eval_num_dbl(v);
521         } else {
522             throw not_implemented_error("double numerical eval is not implemented for the function '" + get_name()
523                                         + "'");
524         }
525     }
deval_num_dblheyoka::detail::func_inner526     double deval_num_dbl(const std::vector<double> &v, std::vector<double>::size_type i) const final
527     {
528         if constexpr (func_has_deval_num_dbl_v<T>) {
529             return m_value.deval_num_dbl(v, i);
530         } else {
531             throw not_implemented_error("double numerical eval of the derivative is not implemented for the function '"
532                                         + get_name() + "'");
533         }
534     }
535 
536     // Taylor.
taylor_decomposeheyoka::detail::func_inner537     taylor_dc_t::size_type taylor_decompose(taylor_dc_t &dc) && final
538     {
539         if constexpr (func_has_taylor_decompose_v<T>) {
540             return std::move(m_value).taylor_decompose(dc);
541         }
542 
543         // LCOV_EXCL_START
544         assert(false);
545         throw;
546         // LCOV_EXCL_STOP
547     }
has_taylor_decomposeheyoka::detail::func_inner548     bool has_taylor_decompose() const final
549     {
550         return func_has_taylor_decompose_v<T>;
551     }
taylor_diff_dblheyoka::detail::func_inner552     llvm::Value *taylor_diff_dbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
553                                  const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
554                                  std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
555                                  std::uint32_t batch_size, bool high_accuracy) const final
556     {
557         if constexpr (func_has_taylor_diff_dbl_v<T>) {
558             return m_value.taylor_diff_dbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
559                                            high_accuracy);
560         } else {
561             throw not_implemented_error("double Taylor diff is not implemented for the function '" + get_name() + "'");
562         }
563     }
taylor_diff_ldblheyoka::detail::func_inner564     llvm::Value *taylor_diff_ldbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
565                                   const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
566                                   std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
567                                   std::uint32_t batch_size, bool high_accuracy) const final
568     {
569         if constexpr (func_has_taylor_diff_ldbl_v<T>) {
570             return m_value.taylor_diff_ldbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
571                                             high_accuracy);
572         } else {
573             throw not_implemented_error("long double Taylor diff is not implemented for the function '" + get_name()
574                                         + "'");
575         }
576     }
577 #if defined(HEYOKA_HAVE_REAL128)
taylor_diff_f128heyoka::detail::func_inner578     llvm::Value *taylor_diff_f128(llvm_state &s, const std::vector<std::uint32_t> &deps,
579                                   const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
580                                   std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
581                                   std::uint32_t batch_size, bool high_accuracy) const final
582     {
583         if constexpr (func_has_taylor_diff_f128_v<T>) {
584             return m_value.taylor_diff_f128(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
585                                             high_accuracy);
586         } else {
587             throw not_implemented_error("float128 Taylor diff is not implemented for the function '" + get_name()
588                                         + "'");
589         }
590     }
591 #endif
taylor_c_diff_func_dblheyoka::detail::func_inner592     llvm::Function *taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
593                                            bool high_accuracy) const final
594     {
595         if constexpr (func_has_taylor_c_diff_func_dbl_v<T>) {
596             return m_value.taylor_c_diff_func_dbl(s, n_uvars, batch_size, high_accuracy);
597         } else {
598             throw not_implemented_error("double Taylor diff in compact mode is not implemented for the function '"
599                                         + get_name() + "'");
600         }
601     }
taylor_c_diff_func_ldblheyoka::detail::func_inner602     llvm::Function *taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
603                                             bool high_accuracy) const final
604     {
605         if constexpr (func_has_taylor_c_diff_func_ldbl_v<T>) {
606             return m_value.taylor_c_diff_func_ldbl(s, n_uvars, batch_size, high_accuracy);
607         } else {
608             throw not_implemented_error("long double Taylor diff in compact mode is not implemented for the function '"
609                                         + get_name() + "'");
610         }
611     }
612 #if defined(HEYOKA_HAVE_REAL128)
taylor_c_diff_func_f128heyoka::detail::func_inner613     llvm::Function *taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
614                                             bool high_accuracy) const final
615     {
616         if constexpr (func_has_taylor_c_diff_func_f128_v<T>) {
617             return m_value.taylor_c_diff_func_f128(s, n_uvars, batch_size, high_accuracy);
618         } else {
619             throw not_implemented_error("float128 Taylor diff in compact mode is not implemented for the function '"
620                                         + get_name() + "'");
621         }
622     }
623 #endif
624 
625 private:
626     // Serialization.
627     friend class boost::serialization::access;
628     template <typename Archive>
serializeheyoka::detail::func_inner629     void serialize(Archive &ar, unsigned)
630     {
631         ar &boost::serialization::base_object<func_inner_base>(*this);
632         ar &m_value;
633     }
634 };
635 
636 template <typename T>
637 using is_func = std::conjunction<std::is_same<T, uncvref_t<T>>, std::is_default_constructible<T>,
638                                  std::is_copy_constructible<T>, std::is_move_constructible<T>, std::is_destructible<T>,
639                                  // https://en.cppreference.com/w/cpp/concepts/derived_from
640                                  // NOTE: use add_pointer/add_cv in order to avoid
641                                  // issues if invoked with problematic types (e.g., void).
642                                  std::is_base_of<func_base, T>,
643                                  std::is_convertible<std::add_pointer_t<std::add_cv_t<T>>, const volatile func_base *>>;
644 
645 } // namespace detail
646 
// Free functions operating on func. They are declared ahead of the class
// because func befriends swap, operator<<, hash and operator==.

// Swap two func objects (non-throwing).
HEYOKA_DLL_PUBLIC void swap(func &, func &) noexcept;

// Stream insertion.
HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const func &);

// Hash value of a func.
HEYOKA_DLL_PUBLIC std::size_t hash(const func &);

// Equality and inequality comparison.
HEYOKA_DLL_PUBLIC bool operator==(const func &, const func &);
HEYOKA_DLL_PUBLIC bool operator!=(const func &, const func &);
655 
656 class HEYOKA_DLL_PUBLIC func
657 {
658     friend HEYOKA_DLL_PUBLIC void swap(func &, func &) noexcept;
659     friend HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const func &);
660     friend HEYOKA_DLL_PUBLIC std::size_t hash(const func &);
661     friend HEYOKA_DLL_PUBLIC bool operator==(const func &, const func &);
662 
663     // Pointer to the inner base.
664     std::shared_ptr<detail::func_inner_base> m_ptr;
665 
666     // Serialization.
667     friend class boost::serialization::access;
668     template <typename Archive>
serialize(Archive & ar,unsigned version)669     void serialize(Archive &ar, unsigned version)
670     {
671         // LCOV_EXCL_START
672         if (version == 0u) {
673             throw std::invalid_argument("Cannot load a function instance from an older archive");
674         }
675         // LCOV_EXCL_STOP
676 
677         ar &m_ptr;
678     }
679 
680     // Just two small helpers to make sure that whenever we require
681     // access to the pointer it actually points to something.
682     const detail::func_inner_base *ptr() const;
683     detail::func_inner_base *ptr();
684 
685     // Private constructor used in the copy() function.
686     HEYOKA_DLL_LOCAL explicit func(std::unique_ptr<detail::func_inner_base>);
687 
688     // Private helper to extract and check the gradient in the
689     // diff() implementations.
690     HEYOKA_DLL_LOCAL std::vector<expression> fetch_gradient(const std::string &) const;
691 
692     template <typename T>
693     using generic_ctor_enabler
694         = std::enable_if_t<std::conjunction_v<std::negation<std::is_same<func, detail::uncvref_t<T>>>,
695                                               detail::is_func<detail::uncvref_t<T>>>,
696                            int>;
697 
698 public:
699     func();
700 
701     template <typename T, generic_ctor_enabler<T &&> = 0>
func(T && x)702     explicit func(T &&x) : m_ptr(std::make_unique<detail::func_inner<detail::uncvref_t<T>>>(std::forward<T>(x)))
703     {
704     }
705 
706     func(const func &);
707     func(func &&) noexcept;
708 
709     func &operator=(const func &);
710     func &operator=(func &&) noexcept;
711 
712     ~func();
713 
714     // NOTE: this creates a new func containing
715     // a copy of the inner object: this means that
716     // the function arguments are shallow-copied and
717     // NOT deep-copied.
718     func copy() const;
719 
720     // NOTE: like in pagmo, this may fail if invoked
721     // from different DLLs in certain situations (e.g.,
722     // Python bindings on OSX). I don't
723     // think this is currently an interesting use case
724     // for heyoka (as we don't provide a way of implementing
725     // new functions in Python), but, if it becomes a problem
726     // in the future, we can solve this in the same way as
727     // in pagmo.
728     template <typename T>
extract() const729     const T *extract() const noexcept
730     {
731         auto p = dynamic_cast<const detail::func_inner<T> *>(ptr());
732         return p == nullptr ? nullptr : &(p->m_value);
733     }
734     template <typename T>
extract()735     T *extract() noexcept
736     {
737         auto p = dynamic_cast<detail::func_inner<T> *>(ptr());
738         return p == nullptr ? nullptr : &(p->m_value);
739     }
740 
741     std::type_index get_type_index() const;
742     const void *get_ptr() const;
743     void *get_ptr();
744 
745     const std::string &get_name() const;
746 
747     const std::vector<expression> &args() const;
748     std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> get_mutable_args_it();
749 
750     llvm::Value *codegen_dbl(llvm_state &, const std::vector<llvm::Value *> &) const;
751     llvm::Value *codegen_ldbl(llvm_state &, const std::vector<llvm::Value *> &) const;
752 #if defined(HEYOKA_HAVE_REAL128)
753     llvm::Value *codegen_f128(llvm_state &, const std::vector<llvm::Value *> &) const;
754 #endif
755 
756     expression diff(std::unordered_map<const void *, expression> &, const std::string &) const;
757     expression diff(std::unordered_map<const void *, expression> &, const param &) const;
758 
759     double eval_dbl(const std::unordered_map<std::string, double> &, const std::vector<double> &) const;
760     long double eval_ldbl(const std::unordered_map<std::string, long double> &, const std::vector<long double> &) const;
761 #if defined(HEYOKA_HAVE_REAL128)
762     mppp::real128 eval_f128(const std::unordered_map<std::string, mppp::real128> &,
763                             const std::vector<mppp::real128> &) const;
764 #endif
765 
766     void eval_batch_dbl(std::vector<double> &, const std::unordered_map<std::string, std::vector<double>> &,
767                         const std::vector<double> &) const;
768     double eval_num_dbl(const std::vector<double> &) const;
769     double deval_num_dbl(const std::vector<double> &, std::vector<double>::size_type) const;
770 
771     taylor_dc_t::size_type taylor_decompose(std::unordered_map<const void *, taylor_dc_t::size_type> &,
772                                             taylor_dc_t &) const;
773     llvm::Value *taylor_diff_dbl(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
774                                  llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
775                                  std::uint32_t, bool) const;
776     llvm::Value *taylor_diff_ldbl(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
777                                   llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
778                                   std::uint32_t, bool) const;
779 #if defined(HEYOKA_HAVE_REAL128)
780     llvm::Value *taylor_diff_f128(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
781                                   llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
782                                   std::uint32_t, bool) const;
783 #endif
784     llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
785     llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
786 #if defined(HEYOKA_HAVE_REAL128)
787     llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
788 #endif
789 };
790 
// Free-function counterparts of the func member functions with the same
// names. Each takes the func as an extra argument (first, or second for
// eval_batch_dbl) and otherwise has the same parameter list and return
// type as the member. Definitions are not in this header (presumably they
// forward to the members; TODO confirm).
HEYOKA_DLL_PUBLIC double eval_dbl(const func &, const std::unordered_map<std::string, double> &,
                                  const std::vector<double> &);
HEYOKA_DLL_PUBLIC long double eval_ldbl(const func &, const std::unordered_map<std::string, long double> &,
                                        const std::vector<long double> &);
#if defined(HEYOKA_HAVE_REAL128)
HEYOKA_DLL_PUBLIC mppp::real128 eval_f128(const func &, const std::unordered_map<std::string, mppp::real128> &,
                                          const std::vector<mppp::real128> &);
#endif

HEYOKA_DLL_PUBLIC void eval_batch_dbl(std::vector<double> &, const func &,
                                      const std::unordered_map<std::string, std::vector<double>> &,
                                      const std::vector<double> &);
HEYOKA_DLL_PUBLIC double eval_num_dbl(const func &, const std::vector<double> &);
HEYOKA_DLL_PUBLIC double deval_num_dbl(const func &, const std::vector<double> &, std::vector<double>::size_type);

// Helpers with no member counterpart in func. update_connections takes the
// connection structure (vector of index vectors) by non-const reference, while
// update_node_values_dbl and update_grad_dbl take it as const; all three take a
// std::size_t by non-const reference.
// NOTE(review): these look like steps of a graph-based evaluation/gradient
// traversal of an expression, with the std::size_t & acting as a running node
// index — confirm against the implementation.
HEYOKA_DLL_PUBLIC void update_connections(std::vector<std::vector<std::size_t>> &, const func &, std::size_t &);
HEYOKA_DLL_PUBLIC void update_node_values_dbl(std::vector<double> &, const func &,
                                              const std::unordered_map<std::string, double> &,
                                              const std::vector<std::vector<std::size_t>> &, std::size_t &);
HEYOKA_DLL_PUBLIC void update_grad_dbl(std::unordered_map<std::string, double> &, const func &,
                                       const std::unordered_map<std::string, double> &, const std::vector<double> &,
                                       const std::vector<std::vector<std::size_t>> &, std::size_t &, double);
813 
814 namespace detail
815 {
816 
817 // Helper to run the codegen of a function-like object with the arguments
818 // represented as a vector of LLVM values.
819 template <typename T, typename F>
codegen_from_values(llvm_state & s,const F & f,const std::vector<llvm::Value * > & args_v)820 inline llvm::Value *codegen_from_values(llvm_state &s, const F &f, const std::vector<llvm::Value *> &args_v)
821 {
822     if constexpr (std::is_same_v<T, double>) {
823         return f.codegen_dbl(s, args_v);
824     } else if constexpr (std::is_same_v<T, long double>) {
825         return f.codegen_ldbl(s, args_v);
826 #if defined(HEYOKA_HAVE_REAL128)
827     } else if constexpr (std::is_same_v<T, mppp::real128>) {
828         return f.codegen_f128(s, args_v);
829 #endif
830     } else {
831         static_assert(detail::always_false_v<T>, "Unhandled type.");
832     }
833 }
834 
835 } // namespace detail
836 
837 } // namespace heyoka
838 
// Current archive version is 1.
// NOTE: func::serialize() throws std::invalid_argument for version-0
// archives; any future version bump should be handled there as well.
BOOST_CLASS_VERSION(heyoka::func, 1)

// Macros for the registration of s11n for concrete functions.
// They export heyoka::detail::func_inner<f> to Boost.Serialization. func
// serializes its std::shared_ptr<detail::func_inner_base> member, so the
// concrete derived type presumably must be registered in order to be
// (de)serialized through the base pointer.
// NOTE(review): per Boost.Serialization conventions, the *_KEY form is meant
// for headers and the *_IMPLEMENT form for exactly one translation unit;
// HEYOKA_S11N_FUNC_EXPORT(f) expands to both.
#define HEYOKA_S11N_FUNC_EXPORT_KEY(f) BOOST_CLASS_EXPORT_KEY(heyoka::detail::func_inner<f>)

#define HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(f) BOOST_CLASS_EXPORT_IMPLEMENT(heyoka::detail::func_inner<f>)

#define HEYOKA_S11N_FUNC_EXPORT(f)                                                                                     \
    HEYOKA_S11N_FUNC_EXPORT_KEY(f)                                                                                     \
    HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(f)
850 
851 #endif
852