1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 #ifndef HEYOKA_FUNC_HPP
10 #define HEYOKA_FUNC_HPP
11
12 #include <heyoka/config.hpp>
13
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <ostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
27
28 #if defined(HEYOKA_HAVE_REAL128)
29
30 #include <mp++/real128.hpp>
31
32 #endif
33
34 #include <heyoka/detail/fwd_decl.hpp>
35 #include <heyoka/detail/llvm_fwd.hpp>
36 #include <heyoka/detail/type_traits.hpp>
37 #include <heyoka/detail/visibility.hpp>
38 #include <heyoka/exceptions.hpp>
39 #include <heyoka/s11n.hpp>
40
41 namespace heyoka
42 {
43
// Base class from which every concrete function type must derive
// (see detail::is_func). It stores the function name and the list of
// arguments; the type-erased wrapper accesses them through the
// non-virtual accessors below.
class HEYOKA_DLL_PUBLIC func_base
{
    // Name of the function.
    std::string m_name;
    // Arguments of the function.
    std::vector<expression> m_args;

    // Serialization.
    // NOTE: the order in which the members are written/read defines
    // the archive format, do not change it.
    friend class boost::serialization::access;
    template <typename Archive>
    void serialize(Archive &ar, unsigned)
    {
        ar &m_name;
        ar &m_args;
    }

public:
    // Construct from a name and a list of arguments.
    explicit func_base(std::string, std::vector<expression>);

    func_base(const func_base &);
    func_base(func_base &&) noexcept;

    func_base &operator=(const func_base &);
    func_base &operator=(func_base &&) noexcept;

    ~func_base();

    // Read-only access to the name and to the arguments.
    const std::string &get_name() const;
    const std::vector<expression> &args() const;
    // Mutable access to the arguments via an iterator pair
    // (presumably the [begin, end) range of m_args - TODO confirm).
    std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> get_mutable_args_it();
};
73
74 namespace detail
75 {
76
// Abstract interface of the type-erased storage used by func.
// func_inner<T> implements every pure virtual function below by
// forwarding to the wrapped object of type T (or by providing a
// fallback when T lacks an optional capability).
struct HEYOKA_DLL_PUBLIC func_inner_base {
    virtual ~func_inner_base();
    // Create a new wrapper holding a copy of the wrapped object.
    virtual std::unique_ptr<func_inner_base> clone() const = 0;

    // Runtime type of the wrapped object and raw access to it.
    virtual std::type_index get_type_index() const = 0;
    virtual const void *get_ptr() const = 0;
    virtual void *get_ptr() = 0;

    // Name of the function.
    virtual const std::string &get_name() const = 0;

    // Stream output.
    virtual void to_stream(std::ostream &) const = 0;

    // Extra hooks for equality comparison and hashing (presumably
    // complementing the comparison of name/arguments - TODO confirm).
    virtual bool extra_equal_to(const func &) const = 0;

    virtual std::size_t extra_hash() const = 0;

    // Read-only and mutable access to the function arguments.
    virtual const std::vector<expression> &args() const = 0;
    virtual std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> get_mutable_args_it() = 0;

    // LLVM code generation for double, long double and (if available) real128.
    virtual llvm::Value *codegen_dbl(llvm_state &, const std::vector<llvm::Value *> &) const = 0;
    virtual llvm::Value *codegen_ldbl(llvm_state &, const std::vector<llvm::Value *> &) const = 0;
#if defined(HEYOKA_HAVE_REAL128)
    virtual llvm::Value *codegen_f128(llvm_state &, const std::vector<llvm::Value *> &) const = 0;
#endif

    // Symbolic differentiation with respect to a variable (by name) or a
    // parameter, and gradient computation. The has_*() functions report
    // whether the corresponding operation is provided by the wrapped type.
    virtual bool has_diff_var() const = 0;
    virtual expression diff(std::unordered_map<const void *, expression> &, const std::string &) const = 0;
    virtual bool has_diff_par() const = 0;
    virtual expression diff(std::unordered_map<const void *, expression> &, const param &) const = 0;
    virtual bool has_gradient() const = 0;
    virtual std::vector<expression> gradient() const = 0;

    // Numerical evaluation, given a map of variable values and a
    // vector of parameter values.
    virtual double eval_dbl(const std::unordered_map<std::string, double> &, const std::vector<double> &) const = 0;
    virtual long double eval_ldbl(const std::unordered_map<std::string, long double> &,
                                  const std::vector<long double> &) const = 0;
#if defined(HEYOKA_HAVE_REAL128)
    virtual mppp::real128 eval_f128(const std::unordered_map<std::string, mppp::real128> &,
                                    const std::vector<mppp::real128> &) const = 0;
#endif

    // Batch evaluation (result written into the first argument) and
    // numerical evaluation of the function/its derivative from a vector
    // of argument values.
    virtual void eval_batch_dbl(std::vector<double> &, const std::unordered_map<std::string, std::vector<double>> &,
                                const std::vector<double> &) const = 0;
    virtual double eval_num_dbl(const std::vector<double> &) const = 0;
    virtual double deval_num_dbl(const std::vector<double> &, std::vector<double>::size_type) const = 0;

    // Taylor decomposition and Taylor derivatives. taylor_decompose()
    // is rvalue-qualified: implementations move out of the wrapped object.
    virtual taylor_dc_t::size_type taylor_decompose(taylor_dc_t &) && = 0;
    virtual bool has_taylor_decompose() const = 0;
    virtual llvm::Value *taylor_diff_dbl(llvm_state &, const std::vector<std::uint32_t> &,
                                         const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *,
                                         std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const = 0;
    virtual llvm::Value *taylor_diff_ldbl(llvm_state &, const std::vector<std::uint32_t> &,
                                          const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *,
                                          std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const = 0;
#if defined(HEYOKA_HAVE_REAL128)
    virtual llvm::Value *taylor_diff_f128(llvm_state &, const std::vector<std::uint32_t> &,
                                          const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *,
                                          std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const = 0;
#endif
    // Compact-mode Taylor derivatives (return an LLVM function).
    virtual llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const = 0;
    virtual llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const = 0;
#if defined(HEYOKA_HAVE_REAL128)
    virtual llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const = 0;
#endif

private:
    // Serialization.
    // NOTE: the base holds no state. The (empty) serialize() is what
    // func_inner's serialize() invokes via base_object<func_inner_base>.
    friend class boost::serialization::access;
    template <typename Archive>
    void serialize(Archive &, unsigned)
    {
    }
};
149
// Detection of the optional hooks that a concrete function type T may
// provide. Each func_has_*_v<T> trait is true only if the member function
// can be invoked on a const T lvalue with the listed argument types AND
// its return type is exactly the one checked via std::is_same_v.

// void to_stream(std::ostream &) const
template <typename T>
using func_to_stream_t
    = decltype(std::declval<std::add_lvalue_reference_t<const T>>().to_stream(std::declval<std::ostream &>()));

template <typename T>
inline constexpr bool func_has_to_stream_v = std::is_same_v<detected_t<func_to_stream_t, T>, void>;

// bool extra_equal_to(const func &) const
template <typename T>
using func_extra_equal_to_t
    = decltype(std::declval<std::add_lvalue_reference_t<const T>>().extra_equal_to(std::declval<const func &>()));

template <typename T>
inline constexpr bool func_has_extra_equal_to_v = std::is_same_v<detected_t<func_extra_equal_to_t, T>, bool>;

// std::size_t extra_hash() const
template <typename T>
using func_extra_hash_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().extra_hash());

template <typename T>
inline constexpr bool func_has_extra_hash_v = std::is_same_v<detected_t<func_extra_hash_t, T>, std::size_t>;
169
// Detection of the LLVM codegen hooks:
// llvm::Value *codegen_{dbl,ldbl,f128}(llvm_state &, const std::vector<llvm::Value *> &) const
template <typename T>
using func_codegen_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().codegen_dbl(
    std::declval<llvm_state &>(), std::declval<const std::vector<llvm::Value *> &>()));

template <typename T>
inline constexpr bool func_has_codegen_dbl_v = std::is_same_v<detected_t<func_codegen_dbl_t, T>, llvm::Value *>;

template <typename T>
using func_codegen_ldbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().codegen_ldbl(
    std::declval<llvm_state &>(), std::declval<const std::vector<llvm::Value *> &>()));

template <typename T>
inline constexpr bool func_has_codegen_ldbl_v = std::is_same_v<detected_t<func_codegen_ldbl_t, T>, llvm::Value *>;

#if defined(HEYOKA_HAVE_REAL128)

// The real128 variant exists only when heyoka is built with real128 support.
template <typename T>
using func_codegen_f128_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().codegen_f128(
    std::declval<llvm_state &>(), std::declval<const std::vector<llvm::Value *> &>()));

template <typename T>
inline constexpr bool func_has_codegen_f128_v = std::is_same_v<detected_t<func_codegen_f128_t, T>, llvm::Value *>;

#endif
194
// Detection of the symbolic differentiation hooks:
// - expression diff(std::unordered_map<const void *, expression> &, const std::string &) const
//   (derivative with respect to a variable given by name);
template <typename T>
using func_diff_var_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().diff(
    std::declval<std::unordered_map<const void *, expression> &>(), std::declval<const std::string &>()));

template <typename T>
inline constexpr bool func_has_diff_var_v = std::is_same_v<detected_t<func_diff_var_t, T>, expression>;

// - expression diff(std::unordered_map<const void *, expression> &, const param &) const
//   (derivative with respect to a parameter);
template <typename T>
using func_diff_par_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().diff(
    std::declval<std::unordered_map<const void *, expression> &>(), std::declval<const param &>()));

template <typename T>
inline constexpr bool func_has_diff_par_v = std::is_same_v<detected_t<func_diff_par_t, T>, expression>;

// - std::vector<expression> gradient() const.
template <typename T>
using func_gradient_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().gradient());

template <typename T>
inline constexpr bool func_has_gradient_v = std::is_same_v<detected_t<func_gradient_t, T>, std::vector<expression>>;
214
// Detection of the numerical evaluation hooks.

// T eval_{dbl,ldbl,f128}(const std::unordered_map<std::string, T> &, const std::vector<T> &) const
template <typename T>
using func_eval_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_dbl(
    std::declval<const std::unordered_map<std::string, double> &>(), std::declval<const std::vector<double> &>()));

template <typename T>
inline constexpr bool func_has_eval_dbl_v = std::is_same_v<detected_t<func_eval_dbl_t, T>, double>;

template <typename T>
using func_eval_ldbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_ldbl(
    std::declval<const std::unordered_map<std::string, long double> &>(),
    std::declval<const std::vector<long double> &>()));

template <typename T>
inline constexpr bool func_has_eval_ldbl_v = std::is_same_v<detected_t<func_eval_ldbl_t, T>, long double>;

#if defined(HEYOKA_HAVE_REAL128)
template <typename T>
using func_eval_f128_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_f128(
    std::declval<const std::unordered_map<std::string, mppp::real128> &>(),
    std::declval<const std::vector<mppp::real128> &>()));

template <typename T>
inline constexpr bool func_has_eval_f128_v = std::is_same_v<detected_t<func_eval_f128_t, T>, mppp::real128>;
#endif

// void eval_batch_dbl(std::vector<double> &, const std::unordered_map<std::string, std::vector<double>> &,
//                     const std::vector<double> &) const
template <typename T>
using func_eval_batch_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_batch_dbl(
    std::declval<std::vector<double> &>(), std::declval<const std::unordered_map<std::string, std::vector<double>> &>(),
    std::declval<const std::vector<double> &>()));

template <typename T>
inline constexpr bool func_has_eval_batch_dbl_v = std::is_same_v<detected_t<func_eval_batch_dbl_t, T>, void>;

// double eval_num_dbl(const std::vector<double> &) const
template <typename T>
using func_eval_num_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().eval_num_dbl(
    std::declval<const std::vector<double> &>()));

template <typename T>
inline constexpr bool func_has_eval_num_dbl_v = std::is_same_v<detected_t<func_eval_num_dbl_t, T>, double>;

// double deval_num_dbl(const std::vector<double> &, std::vector<double>::size_type) const
template <typename T>
using func_deval_num_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().deval_num_dbl(
    std::declval<const std::vector<double> &>(), std::declval<std::vector<double>::size_type>()));

template <typename T>
inline constexpr bool func_has_deval_num_dbl_v = std::is_same_v<detected_t<func_deval_num_dbl_t, T>, double>;
261
// Detection of the Taylor-related hooks.

// taylor_dc_t::size_type taylor_decompose(taylor_dc_t &) &&
// NOTE: unlike the other hooks, this one is detected on an rvalue T
// (not a const lvalue), since the decomposition consumes the object.
template <typename T>
using func_taylor_decompose_t
    = decltype(std::declval<std::add_rvalue_reference_t<T>>().taylor_decompose(std::declval<taylor_dc_t &>()));

template <typename T>
inline constexpr bool func_has_taylor_decompose_v
    = std::is_same_v<detected_t<func_taylor_decompose_t, T>, taylor_dc_t::size_type>;

// llvm::Value *taylor_diff_{dbl,ldbl,f128}(llvm_state &, const std::vector<std::uint32_t> &,
//     const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *,
//     std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const
template <typename T>
using func_taylor_diff_dbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_diff_dbl(
    std::declval<llvm_state &>(), std::declval<const std::vector<std::uint32_t> &>(),
    std::declval<const std::vector<llvm::Value *> &>(), std::declval<llvm::Value *>(), std::declval<llvm::Value *>(),
    std::declval<std::uint32_t>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
    std::declval<std::uint32_t>(), std::declval<bool>()));

template <typename T>
inline constexpr bool func_has_taylor_diff_dbl_v = std::is_same_v<detected_t<func_taylor_diff_dbl_t, T>, llvm::Value *>;

template <typename T>
using func_taylor_diff_ldbl_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_diff_ldbl(
    std::declval<llvm_state &>(), std::declval<const std::vector<std::uint32_t> &>(),
    std::declval<const std::vector<llvm::Value *> &>(), std::declval<llvm::Value *>(), std::declval<llvm::Value *>(),
    std::declval<std::uint32_t>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
    std::declval<std::uint32_t>(), std::declval<bool>()));

template <typename T>
inline constexpr bool func_has_taylor_diff_ldbl_v
    = std::is_same_v<detected_t<func_taylor_diff_ldbl_t, T>, llvm::Value *>;

#if defined(HEYOKA_HAVE_REAL128)

template <typename T>
using func_taylor_diff_f128_t = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_diff_f128(
    std::declval<llvm_state &>(), std::declval<const std::vector<std::uint32_t> &>(),
    std::declval<const std::vector<llvm::Value *> &>(), std::declval<llvm::Value *>(), std::declval<llvm::Value *>(),
    std::declval<std::uint32_t>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
    std::declval<std::uint32_t>(), std::declval<bool>()));

template <typename T>
inline constexpr bool func_has_taylor_diff_f128_v
    = std::is_same_v<detected_t<func_taylor_diff_f128_t, T>, llvm::Value *>;

#endif

// Compact-mode variants:
// llvm::Function *taylor_c_diff_func_{dbl,ldbl,f128}(llvm_state &, std::uint32_t, std::uint32_t, bool) const
template <typename T>
using func_taylor_c_diff_func_dbl_t
    = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_c_diff_func_dbl(
        std::declval<llvm_state &>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
        std::declval<bool>()));

template <typename T>
inline constexpr bool func_has_taylor_c_diff_func_dbl_v
    = std::is_same_v<detected_t<func_taylor_c_diff_func_dbl_t, T>, llvm::Function *>;

template <typename T>
using func_taylor_c_diff_func_ldbl_t
    = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_c_diff_func_ldbl(
        std::declval<llvm_state &>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
        std::declval<bool>()));

template <typename T>
inline constexpr bool func_has_taylor_c_diff_func_ldbl_v
    = std::is_same_v<detected_t<func_taylor_c_diff_func_ldbl_t, T>, llvm::Function *>;

#if defined(HEYOKA_HAVE_REAL128)

template <typename T>
using func_taylor_c_diff_func_f128_t
    = decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_c_diff_func_f128(
        std::declval<llvm_state &>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
        std::declval<bool>()));

template <typename T>
inline constexpr bool func_has_taylor_c_diff_func_f128_v
    = std::is_same_v<detected_t<func_taylor_c_diff_func_f128_t, T>, llvm::Function *>;

#endif
339
// Default stream output used by func_inner<T>::to_stream() when T does not
// provide its own to_stream() member function.
HEYOKA_DLL_PUBLIC void func_default_to_stream_impl(std::ostream &, const func_base &);
341
342 template <typename T>
343 struct HEYOKA_DLL_PUBLIC_INLINE_CLASS func_inner final : func_inner_base {
344 T m_value;
345
346 // We just need the def ctor, delete everything else.
347 func_inner() = default;
348 func_inner(const func_inner &) = delete;
349 func_inner(func_inner &&) = delete;
350 func_inner &operator=(const func_inner &) = delete;
351 func_inner &operator=(func_inner &&) = delete;
352
353 // Constructors from T (copy and move variants).
func_innerheyoka::detail::func_inner354 explicit func_inner(const T &x) : m_value(x) {}
func_innerheyoka::detail::func_inner355 explicit func_inner(T &&x) : m_value(std::move(x)) {}
356
357 // The clone function.
cloneheyoka::detail::func_inner358 std::unique_ptr<func_inner_base> clone() const final
359 {
360 return std::make_unique<func_inner>(m_value);
361 }
362
363 // Get the type at runtime.
get_type_indexheyoka::detail::func_inner364 std::type_index get_type_index() const final
365 {
366 return typeid(T);
367 }
368 // Raw getters for the internal instance.
get_ptrheyoka::detail::func_inner369 const void *get_ptr() const final
370 {
371 return &m_value;
372 }
get_ptrheyoka::detail::func_inner373 void *get_ptr() final
374 {
375 return &m_value;
376 }
377
get_nameheyoka::detail::func_inner378 const std::string &get_name() const final
379 {
380 // NOTE: make sure we are invoking the member functions
381 // from func_base (these functions could have been overriden
382 // in the derived class).
383 return static_cast<const func_base *>(&m_value)->get_name();
384 }
385
to_streamheyoka::detail::func_inner386 void to_stream(std::ostream &os) const final
387 {
388 if constexpr (func_has_to_stream_v<T>) {
389 m_value.to_stream(os);
390 } else {
391 func_default_to_stream_impl(os, static_cast<const func_base &>(m_value));
392 }
393 }
394
extra_equal_toheyoka::detail::func_inner395 bool extra_equal_to(const func &f) const final
396 {
397 if constexpr (func_has_extra_equal_to_v<T>) {
398 return m_value.extra_equal_to(f);
399 } else {
400 return true;
401 }
402 }
403
extra_hashheyoka::detail::func_inner404 std::size_t extra_hash() const final
405 {
406 if constexpr (func_has_extra_hash_v<T>) {
407 return m_value.extra_hash();
408 } else {
409 return 0;
410 }
411 }
412
argsheyoka::detail::func_inner413 const std::vector<expression> &args() const final
414 {
415 return static_cast<const func_base *>(&m_value)->args();
416 }
get_mutable_args_itheyoka::detail::func_inner417 std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> get_mutable_args_it() final
418 {
419 return static_cast<func_base *>(&m_value)->get_mutable_args_it();
420 }
421
422 // codegen.
codegen_dblheyoka::detail::func_inner423 llvm::Value *codegen_dbl(llvm_state &s, const std::vector<llvm::Value *> &v) const final
424 {
425 if constexpr (func_has_codegen_dbl_v<T>) {
426 return m_value.codegen_dbl(s, v);
427 } else {
428 throw not_implemented_error("double codegen is not implemented for the function '" + get_name() + "'");
429 }
430 }
codegen_ldblheyoka::detail::func_inner431 llvm::Value *codegen_ldbl(llvm_state &s, const std::vector<llvm::Value *> &v) const final
432 {
433 if constexpr (func_has_codegen_ldbl_v<T>) {
434 return m_value.codegen_ldbl(s, v);
435 } else {
436 throw not_implemented_error("long double codegen is not implemented for the function '" + get_name() + "'");
437 }
438 }
439 #if defined(HEYOKA_HAVE_REAL128)
codegen_f128heyoka::detail::func_inner440 llvm::Value *codegen_f128(llvm_state &s, const std::vector<llvm::Value *> &v) const final
441 {
442 if constexpr (func_has_codegen_f128_v<T>) {
443 return m_value.codegen_f128(s, v);
444 } else {
445 throw not_implemented_error("float128 codegen is not implemented for the function '" + get_name() + "'");
446 }
447 }
448 #endif
449
450 // diff.
has_diff_varheyoka::detail::func_inner451 bool has_diff_var() const final
452 {
453 return func_has_diff_var_v<T>;
454 }
455 expression diff(std::unordered_map<const void *, expression> &, const std::string &) const final;
has_diff_parheyoka::detail::func_inner456 bool has_diff_par() const final
457 {
458 return func_has_diff_par_v<T>;
459 }
460 expression diff(std::unordered_map<const void *, expression> &, const param &) const final;
461
462 // gradient.
has_gradientheyoka::detail::func_inner463 bool has_gradient() const final
464 {
465 return func_has_gradient_v<T>;
466 }
gradientheyoka::detail::func_inner467 std::vector<expression> gradient() const final
468 {
469 if constexpr (func_has_gradient_v<T>) {
470 return m_value.gradient();
471 }
472
473 // LCOV_EXCL_START
474 assert(false);
475 throw;
476 // LCOV_EXCL_STOP
477 }
478
479 // eval.
eval_dblheyoka::detail::func_inner480 double eval_dbl(const std::unordered_map<std::string, double> &m, const std::vector<double> &pars) const final
481 {
482 if constexpr (func_has_eval_dbl_v<T>) {
483 return m_value.eval_dbl(m, pars);
484 } else {
485 throw not_implemented_error("double eval is not implemented for the function '" + get_name() + "'");
486 }
487 }
eval_ldblheyoka::detail::func_inner488 long double eval_ldbl(const std::unordered_map<std::string, long double> &m,
489 const std::vector<long double> &pars) const final
490 {
491 if constexpr (func_has_eval_ldbl_v<T>) {
492 return m_value.eval_ldbl(m, pars);
493 } else {
494 throw not_implemented_error("long double eval is not implemented for the function '" + get_name() + "'");
495 }
496 }
497 #if defined(HEYOKA_HAVE_REAL128)
eval_f128heyoka::detail::func_inner498 mppp::real128 eval_f128(const std::unordered_map<std::string, mppp::real128> &m,
499 const std::vector<mppp::real128> &pars) const final
500 {
501 if constexpr (func_has_eval_f128_v<T>) {
502 return m_value.eval_f128(m, pars);
503 } else {
504 throw not_implemented_error("mppp::real128 eval is not implemented for the function '" + get_name() + "'");
505 }
506 }
507 #endif
eval_batch_dblheyoka::detail::func_inner508 void eval_batch_dbl(std::vector<double> &out, const std::unordered_map<std::string, std::vector<double>> &m,
509 const std::vector<double> &pars) const final
510 {
511 if constexpr (func_has_eval_batch_dbl_v<T>) {
512 m_value.eval_batch_dbl(out, m, pars);
513 } else {
514 throw not_implemented_error("double batch eval is not implemented for the function '" + get_name() + "'");
515 }
516 }
eval_num_dblheyoka::detail::func_inner517 double eval_num_dbl(const std::vector<double> &v) const final
518 {
519 if constexpr (func_has_eval_num_dbl_v<T>) {
520 return m_value.eval_num_dbl(v);
521 } else {
522 throw not_implemented_error("double numerical eval is not implemented for the function '" + get_name()
523 + "'");
524 }
525 }
deval_num_dblheyoka::detail::func_inner526 double deval_num_dbl(const std::vector<double> &v, std::vector<double>::size_type i) const final
527 {
528 if constexpr (func_has_deval_num_dbl_v<T>) {
529 return m_value.deval_num_dbl(v, i);
530 } else {
531 throw not_implemented_error("double numerical eval of the derivative is not implemented for the function '"
532 + get_name() + "'");
533 }
534 }
535
536 // Taylor.
taylor_decomposeheyoka::detail::func_inner537 taylor_dc_t::size_type taylor_decompose(taylor_dc_t &dc) && final
538 {
539 if constexpr (func_has_taylor_decompose_v<T>) {
540 return std::move(m_value).taylor_decompose(dc);
541 }
542
543 // LCOV_EXCL_START
544 assert(false);
545 throw;
546 // LCOV_EXCL_STOP
547 }
has_taylor_decomposeheyoka::detail::func_inner548 bool has_taylor_decompose() const final
549 {
550 return func_has_taylor_decompose_v<T>;
551 }
taylor_diff_dblheyoka::detail::func_inner552 llvm::Value *taylor_diff_dbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
553 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
554 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
555 std::uint32_t batch_size, bool high_accuracy) const final
556 {
557 if constexpr (func_has_taylor_diff_dbl_v<T>) {
558 return m_value.taylor_diff_dbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
559 high_accuracy);
560 } else {
561 throw not_implemented_error("double Taylor diff is not implemented for the function '" + get_name() + "'");
562 }
563 }
taylor_diff_ldblheyoka::detail::func_inner564 llvm::Value *taylor_diff_ldbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
565 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
566 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
567 std::uint32_t batch_size, bool high_accuracy) const final
568 {
569 if constexpr (func_has_taylor_diff_ldbl_v<T>) {
570 return m_value.taylor_diff_ldbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
571 high_accuracy);
572 } else {
573 throw not_implemented_error("long double Taylor diff is not implemented for the function '" + get_name()
574 + "'");
575 }
576 }
577 #if defined(HEYOKA_HAVE_REAL128)
taylor_diff_f128heyoka::detail::func_inner578 llvm::Value *taylor_diff_f128(llvm_state &s, const std::vector<std::uint32_t> &deps,
579 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
580 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
581 std::uint32_t batch_size, bool high_accuracy) const final
582 {
583 if constexpr (func_has_taylor_diff_f128_v<T>) {
584 return m_value.taylor_diff_f128(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
585 high_accuracy);
586 } else {
587 throw not_implemented_error("float128 Taylor diff is not implemented for the function '" + get_name()
588 + "'");
589 }
590 }
591 #endif
taylor_c_diff_func_dblheyoka::detail::func_inner592 llvm::Function *taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
593 bool high_accuracy) const final
594 {
595 if constexpr (func_has_taylor_c_diff_func_dbl_v<T>) {
596 return m_value.taylor_c_diff_func_dbl(s, n_uvars, batch_size, high_accuracy);
597 } else {
598 throw not_implemented_error("double Taylor diff in compact mode is not implemented for the function '"
599 + get_name() + "'");
600 }
601 }
taylor_c_diff_func_ldblheyoka::detail::func_inner602 llvm::Function *taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
603 bool high_accuracy) const final
604 {
605 if constexpr (func_has_taylor_c_diff_func_ldbl_v<T>) {
606 return m_value.taylor_c_diff_func_ldbl(s, n_uvars, batch_size, high_accuracy);
607 } else {
608 throw not_implemented_error("long double Taylor diff in compact mode is not implemented for the function '"
609 + get_name() + "'");
610 }
611 }
612 #if defined(HEYOKA_HAVE_REAL128)
taylor_c_diff_func_f128heyoka::detail::func_inner613 llvm::Function *taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
614 bool high_accuracy) const final
615 {
616 if constexpr (func_has_taylor_c_diff_func_f128_v<T>) {
617 return m_value.taylor_c_diff_func_f128(s, n_uvars, batch_size, high_accuracy);
618 } else {
619 throw not_implemented_error("float128 Taylor diff in compact mode is not implemented for the function '"
620 + get_name() + "'");
621 }
622 }
623 #endif
624
625 private:
626 // Serialization.
627 friend class boost::serialization::access;
628 template <typename Archive>
serializeheyoka::detail::func_inner629 void serialize(Archive &ar, unsigned)
630 {
631 ar &boost::serialization::base_object<func_inner_base>(*this);
632 ar &m_value;
633 }
634 };
635
// Requirements for a type T to be usable as a concrete function in func:
// - T must be cv/ref-unqualified;
// - T must be default-, copy- and move-constructible and destructible;
// - T must derive from func_base, and a pointer to T must be convertible
//   to a pointer to func_base (i.e., the base is public and unambiguous).
template <typename T>
using is_func = std::conjunction<std::is_same<T, uncvref_t<T>>, std::is_default_constructible<T>,
                                 std::is_copy_constructible<T>, std::is_move_constructible<T>, std::is_destructible<T>,
                                 // https://en.cppreference.com/w/cpp/concepts/derived_from
                                 // NOTE: use add_pointer/add_cv in order to avoid
                                 // issues if invoked with problematic types (e.g., void).
                                 std::is_base_of<func_base, T>,
                                 std::is_convertible<std::add_pointer_t<std::add_cv_t<T>>, const volatile func_base *>>;
644
645 } // namespace detail
646
// Free functions operating on func. swap(), operator<<, hash() and
// operator== are declared as friends of func, since they need access
// to its internals.
HEYOKA_DLL_PUBLIC void swap(func &, func &) noexcept;

HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const func &);

HEYOKA_DLL_PUBLIC std::size_t hash(const func &);

HEYOKA_DLL_PUBLIC bool operator==(const func &, const func &);
HEYOKA_DLL_PUBLIC bool operator!=(const func &, const func &);
655
// Type-erased wrapper around a concrete function object (any type
// satisfying detail::is_func). The object is stored in a
// detail::func_inner<T> held through a shared_ptr, and all operations
// are dispatched through the detail::func_inner_base virtual interface.
class HEYOKA_DLL_PUBLIC func
{
    friend HEYOKA_DLL_PUBLIC void swap(func &, func &) noexcept;
    friend HEYOKA_DLL_PUBLIC std::ostream &operator<<(std::ostream &, const func &);
    friend HEYOKA_DLL_PUBLIC std::size_t hash(const func &);
    friend HEYOKA_DLL_PUBLIC bool operator==(const func &, const func &);

    // Pointer to the inner base.
    // NOTE(review): being a shared_ptr, copies of a func presumably share
    // the inner object, with copy() providing an independent instance -
    // TODO confirm against the copy constructor implementation.
    std::shared_ptr<detail::func_inner_base> m_ptr;

    // Serialization.
    // NOTE: archives written with class version 0 are rejected (the current
    // version is set via BOOST_CLASS_VERSION after the namespace).
    friend class boost::serialization::access;
    template <typename Archive>
    void serialize(Archive &ar, unsigned version)
    {
        // LCOV_EXCL_START
        if (version == 0u) {
            throw std::invalid_argument("Cannot load a function instance from an older archive");
        }
        // LCOV_EXCL_STOP

        ar &m_ptr;
    }

    // Just two small helpers to make sure that whenever we require
    // access to the pointer it actually points to something.
    const detail::func_inner_base *ptr() const;
    detail::func_inner_base *ptr();

    // Private constructor used in the copy() function.
    HEYOKA_DLL_LOCAL explicit func(std::unique_ptr<detail::func_inner_base>);

    // Private helper to extract and check the gradient in the
    // diff() implementations.
    HEYOKA_DLL_LOCAL std::vector<expression> fetch_gradient(const std::string &) const;

    // Enables the generic constructor only for types that, once
    // cv/ref qualifiers are removed, are not func itself and satisfy
    // detail::is_func.
    template <typename T>
    using generic_ctor_enabler
        = std::enable_if_t<std::conjunction_v<std::negation<std::is_same<func, detail::uncvref_t<T>>>,
                                              detail::is_func<detail::uncvref_t<T>>>,
                           int>;

public:
    func();

    // Generic constructor: wraps (by copy or move, depending on the
    // value category of x) a concrete function object.
    template <typename T, generic_ctor_enabler<T &&> = 0>
    explicit func(T &&x) : m_ptr(std::make_unique<detail::func_inner<detail::uncvref_t<T>>>(std::forward<T>(x)))
    {
    }

    func(const func &);
    func(func &&) noexcept;

    func &operator=(const func &);
    func &operator=(func &&) noexcept;

    ~func();

    // NOTE: this creates a new func containing
    // a copy of the inner object: this means that
    // the function arguments are shallow-copied and
    // NOT deep-copied.
    func copy() const;

    // Access to the wrapped object: returns a pointer to it if its type
    // is exactly T, nullptr otherwise.
    // NOTE: like in pagmo, this may fail if invoked
    // from different DLLs in certain situations (e.g.,
    // Python bindings on OSX). I don't
    // think this is currently an interesting use case
    // for heyoka (as we don't provide a way of implementing
    // new functions in Python), but, if it becomes a problem
    // in the future, we can solve this in the same way as
    // in pagmo.
    template <typename T>
    const T *extract() const noexcept
    {
        auto p = dynamic_cast<const detail::func_inner<T> *>(ptr());
        return p == nullptr ? nullptr : &(p->m_value);
    }
    template <typename T>
    T *extract() noexcept
    {
        auto p = dynamic_cast<detail::func_inner<T> *>(ptr());
        return p == nullptr ? nullptr : &(p->m_value);
    }

    // Runtime type information and raw access to the wrapped object.
    std::type_index get_type_index() const;
    const void *get_ptr() const;
    void *get_ptr();

    const std::string &get_name() const;

    const std::vector<expression> &args() const;
    std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> get_mutable_args_it();

    // The functions below mirror the detail::func_inner_base interface.
    llvm::Value *codegen_dbl(llvm_state &, const std::vector<llvm::Value *> &) const;
    llvm::Value *codegen_ldbl(llvm_state &, const std::vector<llvm::Value *> &) const;
#if defined(HEYOKA_HAVE_REAL128)
    llvm::Value *codegen_f128(llvm_state &, const std::vector<llvm::Value *> &) const;
#endif

    expression diff(std::unordered_map<const void *, expression> &, const std::string &) const;
    expression diff(std::unordered_map<const void *, expression> &, const param &) const;

    double eval_dbl(const std::unordered_map<std::string, double> &, const std::vector<double> &) const;
    long double eval_ldbl(const std::unordered_map<std::string, long double> &, const std::vector<long double> &) const;
#if defined(HEYOKA_HAVE_REAL128)
    mppp::real128 eval_f128(const std::unordered_map<std::string, mppp::real128> &,
                            const std::vector<mppp::real128> &) const;
#endif

    void eval_batch_dbl(std::vector<double> &, const std::unordered_map<std::string, std::vector<double>> &,
                        const std::vector<double> &) const;
    double eval_num_dbl(const std::vector<double> &) const;
    double deval_num_dbl(const std::vector<double> &, std::vector<double>::size_type) const;

    // NOTE: unlike func_inner_base::taylor_decompose(), this overload is
    // const and takes an additional map keyed on const void * (presumably
    // used to avoid decomposing shared subexpressions twice - TODO confirm).
    taylor_dc_t::size_type taylor_decompose(std::unordered_map<const void *, taylor_dc_t::size_type> &,
                                            taylor_dc_t &) const;
    llvm::Value *taylor_diff_dbl(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
                                 llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
                                 std::uint32_t, bool) const;
    llvm::Value *taylor_diff_ldbl(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
                                  llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
                                  std::uint32_t, bool) const;
#if defined(HEYOKA_HAVE_REAL128)
    llvm::Value *taylor_diff_f128(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
                                  llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
                                  std::uint32_t, bool) const;
#endif
    llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
    llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
#if defined(HEYOKA_HAVE_REAL128)
    llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
#endif
};
790
// Free-function numerical evaluation API for func objects.
// NOTE(review): presumably these forward to the corresponding func
// member functions - TODO confirm in the implementation file.
HEYOKA_DLL_PUBLIC double eval_dbl(const func &, const std::unordered_map<std::string, double> &,
                                  const std::vector<double> &);
HEYOKA_DLL_PUBLIC long double eval_ldbl(const func &, const std::unordered_map<std::string, long double> &,
                                        const std::vector<long double> &);
#if defined(HEYOKA_HAVE_REAL128)
HEYOKA_DLL_PUBLIC mppp::real128 eval_f128(const func &, const std::unordered_map<std::string, mppp::real128> &,
                                          const std::vector<mppp::real128> &);
#endif

HEYOKA_DLL_PUBLIC void eval_batch_dbl(std::vector<double> &, const func &,
                                      const std::unordered_map<std::string, std::vector<double>> &,
                                      const std::vector<double> &);
HEYOKA_DLL_PUBLIC double eval_num_dbl(const func &, const std::vector<double> &);
HEYOKA_DLL_PUBLIC double deval_num_dbl(const func &, const std::vector<double> &, std::vector<double>::size_type);

// Helpers that appear to operate on a node-based representation of an
// expression graph (adjacency lists stored as vector<vector<size_t>>,
// with a running node counter passed by reference) - TODO confirm the
// exact semantics in the implementation file.
HEYOKA_DLL_PUBLIC void update_connections(std::vector<std::vector<std::size_t>> &, const func &, std::size_t &);
HEYOKA_DLL_PUBLIC void update_node_values_dbl(std::vector<double> &, const func &,
                                              const std::unordered_map<std::string, double> &,
                                              const std::vector<std::vector<std::size_t>> &, std::size_t &);
HEYOKA_DLL_PUBLIC void update_grad_dbl(std::unordered_map<std::string, double> &, const func &,
                                       const std::unordered_map<std::string, double> &, const std::vector<double> &,
                                       const std::vector<std::vector<std::size_t>> &, std::size_t &, double);
813
814 namespace detail
815 {
816
817 // Helper to run the codegen of a function-like object with the arguments
818 // represented as a vector of LLVM values.
819 template <typename T, typename F>
codegen_from_values(llvm_state & s,const F & f,const std::vector<llvm::Value * > & args_v)820 inline llvm::Value *codegen_from_values(llvm_state &s, const F &f, const std::vector<llvm::Value *> &args_v)
821 {
822 if constexpr (std::is_same_v<T, double>) {
823 return f.codegen_dbl(s, args_v);
824 } else if constexpr (std::is_same_v<T, long double>) {
825 return f.codegen_ldbl(s, args_v);
826 #if defined(HEYOKA_HAVE_REAL128)
827 } else if constexpr (std::is_same_v<T, mppp::real128>) {
828 return f.codegen_f128(s, args_v);
829 #endif
830 } else {
831 static_assert(detail::always_false_v<T>, "Unhandled type.");
832 }
833 }
834
835 } // namespace detail
836
837 } // namespace heyoka
838
// Current archive version is 1.
// NOTE: func::serialize() refuses to load archives with version 0.
BOOST_CLASS_VERSION(heyoka::func, 1)

// Macros for the registration of s11n for concrete functions.
// A func stores its concrete function in a heyoka::detail::func_inner<f>
// accessed through a pointer to the base class, so each concrete type f
// must be registered with Boost.Serialization's export mechanism to be
// (de)serialised polymorphically:
// - HEYOKA_S11N_FUNC_EXPORT_KEY(f) declares the export key (GUID);
// - HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(f) instantiates the registration;
// - HEYOKA_S11N_FUNC_EXPORT(f) does both at once.
#define HEYOKA_S11N_FUNC_EXPORT_KEY(f) BOOST_CLASS_EXPORT_KEY(heyoka::detail::func_inner<f>)

#define HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(f) BOOST_CLASS_EXPORT_IMPLEMENT(heyoka::detail::func_inner<f>)

#define HEYOKA_S11N_FUNC_EXPORT(f)                                                                                     \
    HEYOKA_S11N_FUNC_EXPORT_KEY(f)                                                                                     \
    HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(f)
850
851 #endif
852