1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <heyoka/config.hpp>
10 
11 #include <algorithm>
12 #include <cassert>
13 #include <cstddef>
14 #include <cstdint>
15 #include <functional>
16 #include <initializer_list>
17 #include <memory>
18 #include <ostream>
19 #include <stdexcept>
20 #include <string>
21 #include <type_traits>
22 #include <typeindex>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <utility>
26 #include <variant>
27 #include <vector>
28 
29 #include <boost/version.hpp>
30 
31 // NOTE: the header for hash_combine changed in version 1.67.
32 #if (BOOST_VERSION / 100000 > 1) || (BOOST_VERSION / 100000 == 1 && BOOST_VERSION / 100 % 1000 >= 67)
33 
34 #include <boost/container_hash/hash.hpp>
35 
36 #else
37 
38 #include <boost/functional/hash.hpp>
39 
40 #endif
41 
42 #include <fmt/format.h>
43 
44 #if defined(HEYOKA_HAVE_REAL128)
45 
46 #include <mp++/real128.hpp>
47 
48 #endif
49 
50 #include <heyoka/detail/fwd_decl.hpp>
51 #include <heyoka/detail/llvm_fwd.hpp>
52 #include <heyoka/detail/llvm_helpers.hpp>
53 #include <heyoka/detail/type_traits.hpp>
54 #include <heyoka/exceptions.hpp>
55 #include <heyoka/expression.hpp>
56 #include <heyoka/func.hpp>
57 #include <heyoka/math/sum.hpp>
58 #include <heyoka/number.hpp>
59 #include <heyoka/param.hpp>
60 #include <heyoka/variable.hpp>
61 
62 #if defined(_MSC_VER) && !defined(__clang__)
63 
64 // NOTE: MSVC has issues with the other "using"
65 // statement form.
66 using namespace fmt::literals;
67 
68 #else
69 
70 using fmt::literals::operator""_format;
71 
72 #endif
73 
74 namespace heyoka
75 {
76 
func_base(std::string name,std::vector<expression> args)77 func_base::func_base(std::string name, std::vector<expression> args) : m_name(std::move(name)), m_args(std::move(args))
78 {
79     if (m_name.empty()) {
80         throw std::invalid_argument("Cannot create a function with no name");
81     }
82 }
83 
// Special member functions of func_base: they are all defined
// out-of-line with the compiler-generated implementation.

// Copy constructor.
func_base::func_base(const func_base &) = default;

// Move constructor (declared noexcept).
func_base::func_base(func_base &&) noexcept = default;

// Copy assignment operator.
func_base &func_base::operator=(const func_base &) = default;

// Move assignment operator (declared noexcept).
func_base &func_base::operator=(func_base &&) noexcept = default;

// Destructor.
func_base::~func_base() = default;
93 
get_name() const94 const std::string &func_base::get_name() const
95 {
96     return m_name;
97 }
98 
args() const99 const std::vector<expression> &func_base::args() const
100 {
101     return m_args;
102 }
103 
get_mutable_args_it()104 std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> func_base::get_mutable_args_it()
105 {
106     return {m_args.begin(), m_args.end()};
107 }
108 
109 namespace detail
110 {
111 
112 namespace
113 {
114 
115 // Helper to check if a vector of llvm values contains
116 // a nullptr.
llvm_valvec_has_null(const std::vector<llvm::Value * > & v)117 bool llvm_valvec_has_null(const std::vector<llvm::Value *> &v)
118 {
119     return std::any_of(v.begin(), v.end(), [](llvm::Value *p) { return p == nullptr; });
120 }
121 
122 } // namespace
123 
124 // Default implementation of to_stream() for func.
func_default_to_stream_impl(std::ostream & os,const func_base & f)125 void func_default_to_stream_impl(std::ostream &os, const func_base &f)
126 {
127     os << f.get_name() << '(';
128 
129     const auto &args = f.args();
130     for (decltype(args.size()) i = 0; i < args.size(); ++i) {
131         os << args[i];
132         if (i != args.size() - 1u) {
133             os << ", ";
134         }
135     }
136 
137     os << ')';
138 }
139 
// Out-of-line definition of the (defaulted) destructor of func_inner_base.
func_inner_base::~func_inner_base() = default;
141 
142 namespace
143 {
144 
145 struct null_func : func_base {
null_funcheyoka::detail::__anon061a6d270311::null_func146     null_func() : func_base("null_func", {}) {}
147 };
148 
149 } // namespace
150 
151 } // namespace detail
152 
// Construct from a unique pointer to the inner object: the raw
// pointer is released from p and m_ptr is initialised with it.
func::func(std::unique_ptr<detail::func_inner_base> p) : m_ptr(p.release()) {}
154 
// Default constructor: delegates to another constructor of func,
// passing a default-constructed detail::null_func.
func::func() : func(detail::null_func{}) {}
156 
// Special member functions of func: they are all defined
// out-of-line with the compiler-generated implementation.

// Copy constructor.
func::func(const func &) = default;

// Move constructor (declared noexcept).
func::func(func &&) noexcept = default;

// Copy assignment operator.
func &func::operator=(const func &) = default;

// Move assignment operator (declared noexcept).
func &func::operator=(func &&) noexcept = default;

// Destructor.
func::~func() = default;
166 
167 // NOTE: this creates a new func containing
168 // a copy of the inner object: this means that
169 // the function arguments are shallow-copied and
170 // NOT deep-copied.
copy() const171 func func::copy() const
172 {
173     return func{m_ptr->clone()};
174 }
175 
176 // Just two small helpers to make sure that whenever we require
177 // access to the pointer it actually points to something.
ptr() const178 const detail::func_inner_base *func::ptr() const
179 {
180     assert(m_ptr.get() != nullptr);
181     return m_ptr.get();
182 }
183 
ptr()184 detail::func_inner_base *func::ptr()
185 {
186     assert(m_ptr.get() != nullptr);
187     return m_ptr.get();
188 }
189 
get_type_index() const190 std::type_index func::get_type_index() const
191 {
192     return ptr()->get_type_index();
193 }
194 
get_ptr() const195 const void *func::get_ptr() const
196 {
197     return ptr()->get_ptr();
198 }
199 
get_ptr()200 void *func::get_ptr()
201 {
202     return ptr()->get_ptr();
203 }
204 
get_name() const205 const std::string &func::get_name() const
206 {
207     return ptr()->get_name();
208 }
209 
args() const210 const std::vector<expression> &func::args() const
211 {
212     return ptr()->args();
213 }
214 
get_mutable_args_it()215 std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> func::get_mutable_args_it()
216 {
217     return ptr()->get_mutable_args_it();
218 }
219 
codegen_dbl(llvm_state & s,const std::vector<llvm::Value * > & v) const220 llvm::Value *func::codegen_dbl(llvm_state &s, const std::vector<llvm::Value *> &v) const
221 {
222     if (v.size() != args().size()) {
223         throw std::invalid_argument(
224             "Inconsistent number of arguments supplied to the double codegen for the function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
225                 get_name(), args().size(), v.size()));
226     }
227 
228     if (detail::llvm_valvec_has_null(v)) {
229         throw std::invalid_argument(
230             "Null pointer detected in the array of values passed to func::codegen_dbl() for the function '{}'"_format(
231                 get_name()));
232     }
233 
234     auto ret = ptr()->codegen_dbl(s, v);
235 
236     if (ret == nullptr) {
237         throw std::invalid_argument(
238             "The double codegen for the function '{}' returned a null pointer"_format(get_name()));
239     }
240 
241     return ret;
242 }
243 
codegen_ldbl(llvm_state & s,const std::vector<llvm::Value * > & v) const244 llvm::Value *func::codegen_ldbl(llvm_state &s, const std::vector<llvm::Value *> &v) const
245 {
246     if (v.size() != args().size()) {
247         throw std::invalid_argument(
248             "Inconsistent number of arguments supplied to the long double codegen for the function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
249                 get_name(), args().size(), v.size()));
250     }
251 
252     if (detail::llvm_valvec_has_null(v)) {
253         throw std::invalid_argument(
254             "Null pointer detected in the array of values passed to func::codegen_ldbl() for the function '{}'"_format(
255                 get_name()));
256     }
257 
258     auto ret = ptr()->codegen_ldbl(s, v);
259 
260     if (ret == nullptr) {
261         throw std::invalid_argument(
262             "The long double codegen for the function '{}' returned a null pointer"_format(get_name()));
263     }
264 
265     return ret;
266 }
267 
268 #if defined(HEYOKA_HAVE_REAL128)
269 
codegen_f128(llvm_state & s,const std::vector<llvm::Value * > & v) const270 llvm::Value *func::codegen_f128(llvm_state &s, const std::vector<llvm::Value *> &v) const
271 {
272     if (v.size() != args().size()) {
273         throw std::invalid_argument(
274             "Inconsistent number of arguments supplied to the float128 codegen for the function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
275                 get_name(), args().size(), v.size()));
276     }
277 
278     if (detail::llvm_valvec_has_null(v)) {
279         throw std::invalid_argument(
280             "Null pointer detected in the array of values passed to func::codegen_f128() for the function '{}'"_format(
281                 get_name()));
282     }
283 
284     auto ret = ptr()->codegen_f128(s, v);
285 
286     if (ret == nullptr) {
287         throw std::invalid_argument(
288             "The float128 codegen for the function '{}' returned a null pointer"_format(get_name()));
289     }
290 
291     return ret;
292 }
293 
294 #endif
295 
fetch_gradient(const std::string & target) const296 std::vector<expression> func::fetch_gradient(const std::string &target) const
297 {
298     // Check if we have the gradient.
299     if (!ptr()->has_gradient()) {
300         throw not_implemented_error("Cannot compute the derivative of the function '{}' with respect to a {}, because "
301                                     "the function does not provide neither a diff() "
302                                     "nor a gradient() member function"_format(get_name(), target));
303     }
304 
305     // Fetch the gradient.
306     auto grad = ptr()->gradient();
307 
308     // Check it.
309     const auto arity = args().size();
310     if (grad.size() != arity) {
311         throw std::invalid_argument(
312             "Inconsistent gradient returned by the function '{}': a vector of {} elements was expected, but the number of elements is {} instead"_format(
313                 get_name(), arity, grad.size()));
314     }
315 
316     return grad;
317 }
318 
diff(std::unordered_map<const void *,expression> & func_map,const std::string & s) const319 expression func::diff(std::unordered_map<const void *, expression> &func_map, const std::string &s) const
320 {
321     // Run the specialised diff implementation,
322     // if available.
323     if (ptr()->has_diff_var()) {
324         return ptr()->diff(func_map, s);
325     }
326 
327     const auto arity = args().size();
328 
329     // Fetch the gradient.
330     auto grad = fetch_gradient("variable");
331 
332     // Compute the total derivative.
333     std::vector<expression> prod;
334     prod.reserve(arity);
335     for (decltype(args().size()) i = 0; i < arity; ++i) {
336         prod.push_back(std::move(grad[i]) * detail::diff(func_map, args()[i], s));
337     }
338 
339     return sum(std::move(prod));
340 }
341 
diff(std::unordered_map<const void *,expression> & func_map,const param & p) const342 expression func::diff(std::unordered_map<const void *, expression> &func_map, const param &p) const
343 {
344     // Run the specialised diff implementation,
345     // if available.
346     if (ptr()->has_diff_par()) {
347         return ptr()->diff(func_map, p);
348     }
349 
350     const auto arity = args().size();
351 
352     // Fetch the gradient.
353     auto grad = fetch_gradient("parameter");
354 
355     // Compute the total derivative.
356     std::vector<expression> prod;
357     prod.reserve(arity);
358     for (decltype(args().size()) i = 0; i < arity; ++i) {
359         prod.push_back(std::move(grad[i]) * detail::diff(func_map, args()[i], p));
360     }
361 
362     return sum(std::move(prod));
363 }
364 
eval_dbl(const std::unordered_map<std::string,double> & m,const std::vector<double> & pars) const365 double func::eval_dbl(const std::unordered_map<std::string, double> &m, const std::vector<double> &pars) const
366 {
367     return ptr()->eval_dbl(m, pars);
368 }
369 
eval_ldbl(const std::unordered_map<std::string,long double> & m,const std::vector<long double> & pars) const370 long double func::eval_ldbl(const std::unordered_map<std::string, long double> &m,
371                             const std::vector<long double> &pars) const
372 {
373     return ptr()->eval_ldbl(m, pars);
374 }
375 
376 #if defined(HEYOKA_HAVE_REAL128)
eval_f128(const std::unordered_map<std::string,mppp::real128> & m,const std::vector<mppp::real128> & pars) const377 mppp::real128 func::eval_f128(const std::unordered_map<std::string, mppp::real128> &m,
378                               const std::vector<mppp::real128> &pars) const
379 {
380     return ptr()->eval_f128(m, pars);
381 }
382 #endif
eval_batch_dbl(std::vector<double> & out,const std::unordered_map<std::string,std::vector<double>> & m,const std::vector<double> & pars) const383 void func::eval_batch_dbl(std::vector<double> &out, const std::unordered_map<std::string, std::vector<double>> &m,
384                           const std::vector<double> &pars) const
385 {
386     ptr()->eval_batch_dbl(out, m, pars);
387 }
388 
eval_num_dbl(const std::vector<double> & v) const389 double func::eval_num_dbl(const std::vector<double> &v) const
390 {
391     if (v.size() != args().size()) {
392         throw std::invalid_argument(
393             "Inconsistent number of arguments supplied to the double numerical evaluation of the function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
394                 get_name(), args().size(), v.size()));
395     }
396 
397     return ptr()->eval_num_dbl(v);
398 }
399 
deval_num_dbl(const std::vector<double> & v,std::vector<double>::size_type i) const400 double func::deval_num_dbl(const std::vector<double> &v, std::vector<double>::size_type i) const
401 {
402     if (v.size() != args().size()) {
403         throw std::invalid_argument(
404             "Inconsistent number of arguments supplied to the double numerical evaluation of the derivative of function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
405                 get_name(), args().size(), v.size()));
406     }
407 
408     if (i >= v.size()) {
409         throw std::invalid_argument(
410             "Invalid index supplied to the double numerical evaluation of the derivative of function '{}': index {} was supplied, but the number of arguments is only {}"_format(
411                 get_name(), args().size(), v.size()));
412     }
413 
414     return ptr()->deval_num_dbl(v, i);
415 }
416 
417 namespace detail
418 {
419 
420 namespace
421 {
422 
423 // Perform the decomposition of the arguments of a function. After this operation,
424 // each argument will be either:
425 // - a variable,
426 // - a number,
427 // - a param.
func_td_args(func & fb,std::unordered_map<const void *,taylor_dc_t::size_type> & func_map,taylor_dc_t & dc)428 void func_td_args(func &fb, std::unordered_map<const void *, taylor_dc_t::size_type> &func_map, taylor_dc_t &dc)
429 {
430     for (auto r = fb.get_mutable_args_it(); r.first != r.second; ++r.first) {
431         if (const auto dres = taylor_decompose(func_map, *r.first, dc)) {
432             *r.first = expression{variable{"u_{}"_format(dres)}};
433         }
434 
435         assert(std::holds_alternative<variable>(r.first->value()) || std::holds_alternative<number>(r.first->value())
436                || std::holds_alternative<param>(r.first->value()));
437     }
438 }
439 
440 } // namespace
441 
442 } // namespace detail
443 
taylor_decompose(std::unordered_map<const void *,taylor_dc_t::size_type> & func_map,taylor_dc_t & dc) const444 taylor_dc_t::size_type func::taylor_decompose(std::unordered_map<const void *, taylor_dc_t::size_type> &func_map,
445                                               taylor_dc_t &dc) const
446 {
447     const auto f_id = get_ptr();
448 
449     if (auto it = func_map.find(f_id); it != func_map.end()) {
450         // We already decomposed the current function, fetch the result
451         // from the cache.
452         return it->second;
453     }
454 
455     // Make a shallow copy: this will be a new function,
456     // but its arguments will be shallow-copied from this.
457     auto f_copy = copy();
458 
459     // Decompose the arguments. This will overwrite
460     // the arguments in f_copy with their decomposition.
461     detail::func_td_args(f_copy, func_map, dc);
462 
463     // Run the decomposition.
464     taylor_dc_t::size_type ret = 0;
465     if (f_copy.ptr()->has_taylor_decompose()) {
466         // Custom implementation.
467         ret = std::move(*f_copy.ptr()).taylor_decompose(dc);
468     } else {
469         // Default implementation: append f_copy and return the index
470         // at which it was appended.
471         dc.emplace_back(std::move(f_copy), std::vector<std::uint32_t>{});
472         ret = dc.size() - 1u;
473     }
474 
475     if (ret == 0u) {
476         throw std::invalid_argument("The return value for the Taylor decomposition of a function can never be zero");
477     }
478 
479     if (ret >= dc.size()) {
480         throw std::invalid_argument(
481             "Invalid value returned by the Taylor decomposition function for the function '{}': "
482             "the return value is {}, which is not less than the current size of the decomposition "
483             "({})"_format(get_name(), ret, dc.size()));
484     }
485 
486     // Update the cache before exiting.
487     [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ret});
488     assert(flag);
489 
490     return ret;
491 }
492 
taylor_diff_dbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy) const493 llvm::Value *func::taylor_diff_dbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
494                                    const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
495                                    std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
496                                    std::uint32_t batch_size, bool high_accuracy) const
497 {
498     if (par_ptr == nullptr) {
499         throw std::invalid_argument(
500             "Null par_ptr detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
501     }
502 
503     if (time_ptr == nullptr) {
504         throw std::invalid_argument(
505             "Null time_ptr detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
506     }
507 
508     if (batch_size == 0u) {
509         throw std::invalid_argument(
510             "Zero batch size detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
511     }
512 
513     if (n_uvars == 0u) {
514         throw std::invalid_argument(
515             "Zero number of u variables detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
516     }
517 
518     auto retval
519         = ptr()->taylor_diff_dbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size, high_accuracy);
520 
521     if (retval == nullptr) {
522         throw std::invalid_argument(
523             "Null return value detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
524     }
525 
526     return retval;
527 }
528 
taylor_diff_ldbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy) const529 llvm::Value *func::taylor_diff_ldbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
530                                     const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
531                                     std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
532                                     std::uint32_t batch_size, bool high_accuracy) const
533 {
534     if (par_ptr == nullptr) {
535         throw std::invalid_argument(
536             "Null par_ptr detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
537     }
538 
539     if (time_ptr == nullptr) {
540         throw std::invalid_argument(
541             "Null time_ptr detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
542     }
543 
544     if (batch_size == 0u) {
545         throw std::invalid_argument(
546             "Zero batch size detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
547     }
548 
549     if (n_uvars == 0u) {
550         throw std::invalid_argument(
551             "Zero number of u variables detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
552     }
553 
554     auto retval
555         = ptr()->taylor_diff_ldbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size, high_accuracy);
556 
557     if (retval == nullptr) {
558         throw std::invalid_argument(
559             "Null return value detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
560     }
561 
562     return retval;
563 }
564 
565 #if defined(HEYOKA_HAVE_REAL128)
566 
taylor_diff_f128(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy) const567 llvm::Value *func::taylor_diff_f128(llvm_state &s, const std::vector<std::uint32_t> &deps,
568                                     const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
569                                     std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
570                                     std::uint32_t batch_size, bool high_accuracy) const
571 {
572     if (par_ptr == nullptr) {
573         throw std::invalid_argument(
574             "Null par_ptr detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
575     }
576 
577     if (time_ptr == nullptr) {
578         throw std::invalid_argument(
579             "Null time_ptr detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
580     }
581 
582     if (batch_size == 0u) {
583         throw std::invalid_argument(
584             "Zero batch size detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
585     }
586 
587     if (n_uvars == 0u) {
588         throw std::invalid_argument(
589             "Zero number of u variables detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
590     }
591 
592     auto retval
593         = ptr()->taylor_diff_f128(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size, high_accuracy);
594 
595     if (retval == nullptr) {
596         throw std::invalid_argument(
597             "Null return value detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
598     }
599 
600     return retval;
601 }
602 
603 #endif
604 
taylor_c_diff_func_dbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy) const605 llvm::Function *func::taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
606                                              bool high_accuracy) const
607 {
608     if (batch_size == 0u) {
609         throw std::invalid_argument(
610             "Zero batch size detected in func::taylor_c_diff_func_dbl() for the function '{}'"_format(get_name()));
611     }
612 
613     if (n_uvars == 0u) {
614         throw std::invalid_argument(
615             "Zero number of u variables detected in func::taylor_c_diff_func_dbl() for the function '{}'"_format(
616                 get_name()));
617     }
618 
619     auto retval = ptr()->taylor_c_diff_func_dbl(s, n_uvars, batch_size, high_accuracy);
620 
621     if (retval == nullptr) {
622         throw std::invalid_argument(
623             "Null return value detected in func::taylor_c_diff_func_dbl() for the function '{}'"_format(get_name()));
624     }
625 
626     return retval;
627 }
628 
taylor_c_diff_func_ldbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy) const629 llvm::Function *func::taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
630                                               bool high_accuracy) const
631 {
632     if (batch_size == 0u) {
633         throw std::invalid_argument(
634             "Zero batch size detected in func::taylor_c_diff_func_ldbl() for the function '{}'"_format(get_name()));
635     }
636 
637     if (n_uvars == 0u) {
638         throw std::invalid_argument(
639             "Zero number of u variables detected in func::taylor_c_diff_func_ldbl() for the function '{}'"_format(
640                 get_name()));
641     }
642 
643     auto retval = ptr()->taylor_c_diff_func_ldbl(s, n_uvars, batch_size, high_accuracy);
644 
645     if (retval == nullptr) {
646         throw std::invalid_argument(
647             "Null return value detected in func::taylor_c_diff_func_ldbl() for the function '{}'"_format(get_name()));
648     }
649 
650     return retval;
651 }
652 
653 #if defined(HEYOKA_HAVE_REAL128)
654 
taylor_c_diff_func_f128(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy) const655 llvm::Function *func::taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
656                                               bool high_accuracy) const
657 {
658     if (batch_size == 0u) {
659         throw std::invalid_argument(
660             "Zero batch size detected in func::taylor_c_diff_func_f128() for the function '{}'"_format(get_name()));
661     }
662 
663     if (n_uvars == 0u) {
664         throw std::invalid_argument(
665             "Zero number of u variables detected in func::taylor_c_diff_func_f128() for the function '{}'"_format(
666                 get_name()));
667     }
668 
669     auto retval = ptr()->taylor_c_diff_func_f128(s, n_uvars, batch_size, high_accuracy);
670 
671     if (retval == nullptr) {
672         throw std::invalid_argument(
673             "Null return value detected in func::taylor_c_diff_func_f128() for the function '{}'"_format(get_name()));
674     }
675 
676     return retval;
677 }
678 
679 #endif
680 
swap(func & a,func & b)681 void swap(func &a, func &b) noexcept
682 {
683     std::swap(a.m_ptr, b.m_ptr);
684 }
685 
operator <<(std::ostream & os,const func & f)686 std::ostream &operator<<(std::ostream &os, const func &f)
687 {
688     f.ptr()->to_stream(os);
689 
690     return os;
691 }
692 
hash(const func & f)693 std::size_t hash(const func &f)
694 {
695     // NOTE: the initial hash value is computed by combining the hash values of:
696     // - the function name,
697     // - the function inner type index,
698     // - the arguments' hashes.
699     std::size_t seed = std::hash<std::string>{}(f.get_name());
700 
701     boost::hash_combine(seed, f.get_type_index());
702 
703     for (const auto &arg : f.args()) {
704         boost::hash_combine(seed, hash(arg));
705     }
706 
707     // Combine with the extra hash value too.
708     boost::hash_combine(seed, f.ptr()->extra_hash());
709 
710     return seed;
711 }
712 
operator ==(const func & a,const func & b)713 bool operator==(const func &a, const func &b)
714 {
715     // Check if the underlying object is the same.
716     if (a.m_ptr == b.m_ptr) {
717         return true;
718     }
719 
720     // NOTE: the initial comparison considers:
721     // - the function name,
722     // - the function inner type index,
723     // - the arguments.
724     // If they are all equal, the extra equality comparison logic
725     // is also run.
726     if (a.get_name() == b.get_name() && a.get_type_index() == b.get_type_index() && a.args() == b.args()) {
727         return a.ptr()->extra_equal_to(b);
728     } else {
729         return false;
730     }
731 }
732 
operator !=(const func & a,const func & b)733 bool operator!=(const func &a, const func &b)
734 {
735     return !(a == b);
736 }
737 
eval_dbl(const func & f,const std::unordered_map<std::string,double> & map,const std::vector<double> & pars)738 double eval_dbl(const func &f, const std::unordered_map<std::string, double> &map, const std::vector<double> &pars)
739 {
740     return f.eval_dbl(map, pars);
741 }
742 
eval_ldbl(const func & f,const std::unordered_map<std::string,long double> & map,const std::vector<long double> & pars)743 long double eval_ldbl(const func &f, const std::unordered_map<std::string, long double> &map,
744                       const std::vector<long double> &pars)
745 {
746     return f.eval_ldbl(map, pars);
747 }
748 
749 #if defined(HEYOKA_HAVE_REAL128)
750 
eval_f128(const func & f,const std::unordered_map<std::string,mppp::real128> & map,const std::vector<mppp::real128> & pars)751 mppp::real128 eval_f128(const func &f, const std::unordered_map<std::string, mppp::real128> &map,
752                         const std::vector<mppp::real128> &pars)
753 {
754     return f.eval_f128(map, pars);
755 }
756 
757 #endif
758 
eval_batch_dbl(std::vector<double> & out_values,const func & f,const std::unordered_map<std::string,std::vector<double>> & map,const std::vector<double> & pars)759 void eval_batch_dbl(std::vector<double> &out_values, const func &f,
760                     const std::unordered_map<std::string, std::vector<double>> &map, const std::vector<double> &pars)
761 {
762     f.eval_batch_dbl(out_values, map, pars);
763 }
764 
eval_num_dbl(const func & f,const std::vector<double> & in)765 double eval_num_dbl(const func &f, const std::vector<double> &in)
766 {
767     return f.eval_num_dbl(in);
768 }
769 
deval_num_dbl(const func & f,const std::vector<double> & in,std::vector<double>::size_type d)770 double deval_num_dbl(const func &f, const std::vector<double> &in, std::vector<double>::size_type d)
771 {
772     return f.deval_num_dbl(in, d);
773 }
774 
update_node_values_dbl(std::vector<double> & node_values,const func & f,const std::unordered_map<std::string,double> & map,const std::vector<std::vector<std::size_t>> & node_connections,std::size_t & node_counter)775 void update_node_values_dbl(std::vector<double> &node_values, const func &f,
776                             const std::unordered_map<std::string, double> &map,
777                             const std::vector<std::vector<std::size_t>> &node_connections, std::size_t &node_counter)
778 {
779     const auto node_id = node_counter;
780     node_counter++;
781     // We have to recurse first as to make sure node_values is filled before being accessed later.
782     for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
783         update_node_values_dbl(node_values, f.args()[i], map, node_connections, node_counter);
784     }
785     // Then we compute
786     std::vector<double> in_values(f.args().size());
787     for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
788         in_values[i] = node_values[node_connections[node_id][i]];
789     }
790     node_values[node_id] = eval_num_dbl(f, in_values);
791 }
792 
update_grad_dbl(std::unordered_map<std::string,double> & grad,const func & f,const std::unordered_map<std::string,double> & map,const std::vector<double> & node_values,const std::vector<std::vector<std::size_t>> & node_connections,std::size_t & node_counter,double acc)793 void update_grad_dbl(std::unordered_map<std::string, double> &grad, const func &f,
794                      const std::unordered_map<std::string, double> &map, const std::vector<double> &node_values,
795                      const std::vector<std::vector<std::size_t>> &node_connections, std::size_t &node_counter,
796                      double acc)
797 {
798     const auto node_id = node_counter;
799     node_counter++;
800     std::vector<double> in_values(f.args().size());
801     for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
802         in_values[i] = node_values[node_connections[node_id][i]];
803     }
804     for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
805         auto value = deval_num_dbl(f, in_values, i);
806         update_grad_dbl(grad, f.args()[i], map, node_values, node_connections, node_counter, acc * value);
807     }
808 }
809 
update_connections(std::vector<std::vector<std::size_t>> & node_connections,const func & f,std::size_t & node_counter)810 void update_connections(std::vector<std::vector<std::size_t>> &node_connections, const func &f,
811                         std::size_t &node_counter)
812 {
813     const auto node_id = node_counter;
814     node_counter++;
815     node_connections.push_back(std::vector<std::size_t>(f.args().size()));
816     for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
817         node_connections[node_id][i] = node_counter;
818         update_connections(node_connections, f.args()[i], node_counter);
819     };
820 }
821 
822 } // namespace heyoka
823