1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 #include <heyoka/config.hpp>
10
11 #include <algorithm>
12 #include <cassert>
13 #include <cstddef>
14 #include <cstdint>
15 #include <functional>
16 #include <initializer_list>
17 #include <memory>
18 #include <ostream>
19 #include <stdexcept>
20 #include <string>
21 #include <type_traits>
22 #include <typeindex>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <utility>
26 #include <variant>
27 #include <vector>
28
29 #include <boost/version.hpp>
30
31 // NOTE: the header for hash_combine changed in version 1.67.
32 #if (BOOST_VERSION / 100000 > 1) || (BOOST_VERSION / 100000 == 1 && BOOST_VERSION / 100 % 1000 >= 67)
33
34 #include <boost/container_hash/hash.hpp>
35
36 #else
37
38 #include <boost/functional/hash.hpp>
39
40 #endif
41
42 #include <fmt/format.h>
43
44 #if defined(HEYOKA_HAVE_REAL128)
45
46 #include <mp++/real128.hpp>
47
48 #endif
49
50 #include <heyoka/detail/fwd_decl.hpp>
51 #include <heyoka/detail/llvm_fwd.hpp>
52 #include <heyoka/detail/llvm_helpers.hpp>
53 #include <heyoka/detail/type_traits.hpp>
54 #include <heyoka/exceptions.hpp>
55 #include <heyoka/expression.hpp>
56 #include <heyoka/func.hpp>
57 #include <heyoka/math/sum.hpp>
58 #include <heyoka/number.hpp>
59 #include <heyoka/param.hpp>
60 #include <heyoka/variable.hpp>
61
62 #if defined(_MSC_VER) && !defined(__clang__)
63
64 // NOTE: MSVC has issues with the other "using"
65 // statement form.
66 using namespace fmt::literals;
67
68 #else
69
70 using fmt::literals::operator""_format;
71
72 #endif
73
74 namespace heyoka
75 {
76
func_base(std::string name,std::vector<expression> args)77 func_base::func_base(std::string name, std::vector<expression> args) : m_name(std::move(name)), m_args(std::move(args))
78 {
79 if (m_name.empty()) {
80 throw std::invalid_argument("Cannot create a function with no name");
81 }
82 }
83
// Copy constructor: compiler-generated memberwise copy.
func_base::func_base(const func_base &) = default;

// Move constructor: compiler-generated memberwise move, declared noexcept.
func_base::func_base(func_base &&) noexcept = default;

// Copy assignment: compiler-generated memberwise copy.
func_base &func_base::operator=(const func_base &) = default;

// Move assignment: compiler-generated memberwise move, declared noexcept.
func_base &func_base::operator=(func_base &&) noexcept = default;

// Destructor: compiler-generated, defined out of line in this translation unit.
func_base::~func_base() = default;
93
get_name() const94 const std::string &func_base::get_name() const
95 {
96 return m_name;
97 }
98
args() const99 const std::vector<expression> &func_base::args() const
100 {
101 return m_args;
102 }
103
get_mutable_args_it()104 std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> func_base::get_mutable_args_it()
105 {
106 return {m_args.begin(), m_args.end()};
107 }
108
109 namespace detail
110 {
111
112 namespace
113 {
114
115 // Helper to check if a vector of llvm values contains
116 // a nullptr.
llvm_valvec_has_null(const std::vector<llvm::Value * > & v)117 bool llvm_valvec_has_null(const std::vector<llvm::Value *> &v)
118 {
119 return std::any_of(v.begin(), v.end(), [](llvm::Value *p) { return p == nullptr; });
120 }
121
122 } // namespace
123
124 // Default implementation of to_stream() for func.
func_default_to_stream_impl(std::ostream & os,const func_base & f)125 void func_default_to_stream_impl(std::ostream &os, const func_base &f)
126 {
127 os << f.get_name() << '(';
128
129 const auto &args = f.args();
130 for (decltype(args.size()) i = 0; i < args.size(); ++i) {
131 os << args[i];
132 if (i != args.size() - 1u) {
133 os << ", ";
134 }
135 }
136
137 os << ')';
138 }
139
// Destructor of the inner base class: compiler-generated,
// defined out of line in this translation unit.
func_inner_base::~func_inner_base() = default;
141
142 namespace
143 {
144
145 struct null_func : func_base {
null_funcheyoka::detail::__anon061a6d270311::null_func146 null_func() : func_base("null_func", {}) {}
147 };
148
149 } // namespace
150
151 } // namespace detail
152
// Construct from a unique pointer to the inner object: the raw pointer
// is released from p and handed over to m_ptr.
func::func(std::unique_ptr<detail::func_inner_base> p) : m_ptr(p.release()) {}
154
// Default constructor: delegates to the construction from
// a detail::null_func object (a function with no arguments).
func::func() : func(detail::null_func{}) {}
156
// Copy constructor: compiler-generated (copies m_ptr, see also copy()).
func::func(const func &) = default;

// Move constructor: compiler-generated, declared noexcept.
func::func(func &&) noexcept = default;

// Copy assignment: compiler-generated.
func &func::operator=(const func &) = default;

// Move assignment: compiler-generated, declared noexcept.
func &func::operator=(func &&) noexcept = default;

// Destructor: compiler-generated, defined out of line in this translation unit.
func::~func() = default;
166
167 // NOTE: this creates a new func containing
168 // a copy of the inner object: this means that
169 // the function arguments are shallow-copied and
170 // NOT deep-copied.
copy() const171 func func::copy() const
172 {
173 return func{m_ptr->clone()};
174 }
175
176 // Just two small helpers to make sure that whenever we require
177 // access to the pointer it actually points to something.
ptr() const178 const detail::func_inner_base *func::ptr() const
179 {
180 assert(m_ptr.get() != nullptr);
181 return m_ptr.get();
182 }
183
ptr()184 detail::func_inner_base *func::ptr()
185 {
186 assert(m_ptr.get() != nullptr);
187 return m_ptr.get();
188 }
189
get_type_index() const190 std::type_index func::get_type_index() const
191 {
192 return ptr()->get_type_index();
193 }
194
get_ptr() const195 const void *func::get_ptr() const
196 {
197 return ptr()->get_ptr();
198 }
199
get_ptr()200 void *func::get_ptr()
201 {
202 return ptr()->get_ptr();
203 }
204
get_name() const205 const std::string &func::get_name() const
206 {
207 return ptr()->get_name();
208 }
209
args() const210 const std::vector<expression> &func::args() const
211 {
212 return ptr()->args();
213 }
214
get_mutable_args_it()215 std::pair<std::vector<expression>::iterator, std::vector<expression>::iterator> func::get_mutable_args_it()
216 {
217 return ptr()->get_mutable_args_it();
218 }
219
codegen_dbl(llvm_state & s,const std::vector<llvm::Value * > & v) const220 llvm::Value *func::codegen_dbl(llvm_state &s, const std::vector<llvm::Value *> &v) const
221 {
222 if (v.size() != args().size()) {
223 throw std::invalid_argument(
224 "Inconsistent number of arguments supplied to the double codegen for the function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
225 get_name(), args().size(), v.size()));
226 }
227
228 if (detail::llvm_valvec_has_null(v)) {
229 throw std::invalid_argument(
230 "Null pointer detected in the array of values passed to func::codegen_dbl() for the function '{}'"_format(
231 get_name()));
232 }
233
234 auto ret = ptr()->codegen_dbl(s, v);
235
236 if (ret == nullptr) {
237 throw std::invalid_argument(
238 "The double codegen for the function '{}' returned a null pointer"_format(get_name()));
239 }
240
241 return ret;
242 }
243
codegen_ldbl(llvm_state & s,const std::vector<llvm::Value * > & v) const244 llvm::Value *func::codegen_ldbl(llvm_state &s, const std::vector<llvm::Value *> &v) const
245 {
246 if (v.size() != args().size()) {
247 throw std::invalid_argument(
248 "Inconsistent number of arguments supplied to the long double codegen for the function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
249 get_name(), args().size(), v.size()));
250 }
251
252 if (detail::llvm_valvec_has_null(v)) {
253 throw std::invalid_argument(
254 "Null pointer detected in the array of values passed to func::codegen_ldbl() for the function '{}'"_format(
255 get_name()));
256 }
257
258 auto ret = ptr()->codegen_ldbl(s, v);
259
260 if (ret == nullptr) {
261 throw std::invalid_argument(
262 "The long double codegen for the function '{}' returned a null pointer"_format(get_name()));
263 }
264
265 return ret;
266 }
267
268 #if defined(HEYOKA_HAVE_REAL128)
269
codegen_f128(llvm_state & s,const std::vector<llvm::Value * > & v) const270 llvm::Value *func::codegen_f128(llvm_state &s, const std::vector<llvm::Value *> &v) const
271 {
272 if (v.size() != args().size()) {
273 throw std::invalid_argument(
274 "Inconsistent number of arguments supplied to the float128 codegen for the function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
275 get_name(), args().size(), v.size()));
276 }
277
278 if (detail::llvm_valvec_has_null(v)) {
279 throw std::invalid_argument(
280 "Null pointer detected in the array of values passed to func::codegen_f128() for the function '{}'"_format(
281 get_name()));
282 }
283
284 auto ret = ptr()->codegen_f128(s, v);
285
286 if (ret == nullptr) {
287 throw std::invalid_argument(
288 "The float128 codegen for the function '{}' returned a null pointer"_format(get_name()));
289 }
290
291 return ret;
292 }
293
294 #endif
295
fetch_gradient(const std::string & target) const296 std::vector<expression> func::fetch_gradient(const std::string &target) const
297 {
298 // Check if we have the gradient.
299 if (!ptr()->has_gradient()) {
300 throw not_implemented_error("Cannot compute the derivative of the function '{}' with respect to a {}, because "
301 "the function does not provide neither a diff() "
302 "nor a gradient() member function"_format(get_name(), target));
303 }
304
305 // Fetch the gradient.
306 auto grad = ptr()->gradient();
307
308 // Check it.
309 const auto arity = args().size();
310 if (grad.size() != arity) {
311 throw std::invalid_argument(
312 "Inconsistent gradient returned by the function '{}': a vector of {} elements was expected, but the number of elements is {} instead"_format(
313 get_name(), arity, grad.size()));
314 }
315
316 return grad;
317 }
318
diff(std::unordered_map<const void *,expression> & func_map,const std::string & s) const319 expression func::diff(std::unordered_map<const void *, expression> &func_map, const std::string &s) const
320 {
321 // Run the specialised diff implementation,
322 // if available.
323 if (ptr()->has_diff_var()) {
324 return ptr()->diff(func_map, s);
325 }
326
327 const auto arity = args().size();
328
329 // Fetch the gradient.
330 auto grad = fetch_gradient("variable");
331
332 // Compute the total derivative.
333 std::vector<expression> prod;
334 prod.reserve(arity);
335 for (decltype(args().size()) i = 0; i < arity; ++i) {
336 prod.push_back(std::move(grad[i]) * detail::diff(func_map, args()[i], s));
337 }
338
339 return sum(std::move(prod));
340 }
341
diff(std::unordered_map<const void *,expression> & func_map,const param & p) const342 expression func::diff(std::unordered_map<const void *, expression> &func_map, const param &p) const
343 {
344 // Run the specialised diff implementation,
345 // if available.
346 if (ptr()->has_diff_par()) {
347 return ptr()->diff(func_map, p);
348 }
349
350 const auto arity = args().size();
351
352 // Fetch the gradient.
353 auto grad = fetch_gradient("parameter");
354
355 // Compute the total derivative.
356 std::vector<expression> prod;
357 prod.reserve(arity);
358 for (decltype(args().size()) i = 0; i < arity; ++i) {
359 prod.push_back(std::move(grad[i]) * detail::diff(func_map, args()[i], p));
360 }
361
362 return sum(std::move(prod));
363 }
364
eval_dbl(const std::unordered_map<std::string,double> & m,const std::vector<double> & pars) const365 double func::eval_dbl(const std::unordered_map<std::string, double> &m, const std::vector<double> &pars) const
366 {
367 return ptr()->eval_dbl(m, pars);
368 }
369
eval_ldbl(const std::unordered_map<std::string,long double> & m,const std::vector<long double> & pars) const370 long double func::eval_ldbl(const std::unordered_map<std::string, long double> &m,
371 const std::vector<long double> &pars) const
372 {
373 return ptr()->eval_ldbl(m, pars);
374 }
375
376 #if defined(HEYOKA_HAVE_REAL128)
eval_f128(const std::unordered_map<std::string,mppp::real128> & m,const std::vector<mppp::real128> & pars) const377 mppp::real128 func::eval_f128(const std::unordered_map<std::string, mppp::real128> &m,
378 const std::vector<mppp::real128> &pars) const
379 {
380 return ptr()->eval_f128(m, pars);
381 }
382 #endif
eval_batch_dbl(std::vector<double> & out,const std::unordered_map<std::string,std::vector<double>> & m,const std::vector<double> & pars) const383 void func::eval_batch_dbl(std::vector<double> &out, const std::unordered_map<std::string, std::vector<double>> &m,
384 const std::vector<double> &pars) const
385 {
386 ptr()->eval_batch_dbl(out, m, pars);
387 }
388
eval_num_dbl(const std::vector<double> & v) const389 double func::eval_num_dbl(const std::vector<double> &v) const
390 {
391 if (v.size() != args().size()) {
392 throw std::invalid_argument(
393 "Inconsistent number of arguments supplied to the double numerical evaluation of the function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
394 get_name(), args().size(), v.size()));
395 }
396
397 return ptr()->eval_num_dbl(v);
398 }
399
deval_num_dbl(const std::vector<double> & v,std::vector<double>::size_type i) const400 double func::deval_num_dbl(const std::vector<double> &v, std::vector<double>::size_type i) const
401 {
402 if (v.size() != args().size()) {
403 throw std::invalid_argument(
404 "Inconsistent number of arguments supplied to the double numerical evaluation of the derivative of function '{}': {} arguments were expected, but {} arguments were provided instead"_format(
405 get_name(), args().size(), v.size()));
406 }
407
408 if (i >= v.size()) {
409 throw std::invalid_argument(
410 "Invalid index supplied to the double numerical evaluation of the derivative of function '{}': index {} was supplied, but the number of arguments is only {}"_format(
411 get_name(), args().size(), v.size()));
412 }
413
414 return ptr()->deval_num_dbl(v, i);
415 }
416
417 namespace detail
418 {
419
420 namespace
421 {
422
423 // Perform the decomposition of the arguments of a function. After this operation,
424 // each argument will be either:
425 // - a variable,
426 // - a number,
427 // - a param.
func_td_args(func & fb,std::unordered_map<const void *,taylor_dc_t::size_type> & func_map,taylor_dc_t & dc)428 void func_td_args(func &fb, std::unordered_map<const void *, taylor_dc_t::size_type> &func_map, taylor_dc_t &dc)
429 {
430 for (auto r = fb.get_mutable_args_it(); r.first != r.second; ++r.first) {
431 if (const auto dres = taylor_decompose(func_map, *r.first, dc)) {
432 *r.first = expression{variable{"u_{}"_format(dres)}};
433 }
434
435 assert(std::holds_alternative<variable>(r.first->value()) || std::holds_alternative<number>(r.first->value())
436 || std::holds_alternative<param>(r.first->value()));
437 }
438 }
439
440 } // namespace
441
442 } // namespace detail
443
taylor_decompose(std::unordered_map<const void *,taylor_dc_t::size_type> & func_map,taylor_dc_t & dc) const444 taylor_dc_t::size_type func::taylor_decompose(std::unordered_map<const void *, taylor_dc_t::size_type> &func_map,
445 taylor_dc_t &dc) const
446 {
447 const auto f_id = get_ptr();
448
449 if (auto it = func_map.find(f_id); it != func_map.end()) {
450 // We already decomposed the current function, fetch the result
451 // from the cache.
452 return it->second;
453 }
454
455 // Make a shallow copy: this will be a new function,
456 // but its arguments will be shallow-copied from this.
457 auto f_copy = copy();
458
459 // Decompose the arguments. This will overwrite
460 // the arguments in f_copy with their decomposition.
461 detail::func_td_args(f_copy, func_map, dc);
462
463 // Run the decomposition.
464 taylor_dc_t::size_type ret = 0;
465 if (f_copy.ptr()->has_taylor_decompose()) {
466 // Custom implementation.
467 ret = std::move(*f_copy.ptr()).taylor_decompose(dc);
468 } else {
469 // Default implementation: append f_copy and return the index
470 // at which it was appended.
471 dc.emplace_back(std::move(f_copy), std::vector<std::uint32_t>{});
472 ret = dc.size() - 1u;
473 }
474
475 if (ret == 0u) {
476 throw std::invalid_argument("The return value for the Taylor decomposition of a function can never be zero");
477 }
478
479 if (ret >= dc.size()) {
480 throw std::invalid_argument(
481 "Invalid value returned by the Taylor decomposition function for the function '{}': "
482 "the return value is {}, which is not less than the current size of the decomposition "
483 "({})"_format(get_name(), ret, dc.size()));
484 }
485
486 // Update the cache before exiting.
487 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ret});
488 assert(flag);
489
490 return ret;
491 }
492
taylor_diff_dbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy) const493 llvm::Value *func::taylor_diff_dbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
494 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
495 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
496 std::uint32_t batch_size, bool high_accuracy) const
497 {
498 if (par_ptr == nullptr) {
499 throw std::invalid_argument(
500 "Null par_ptr detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
501 }
502
503 if (time_ptr == nullptr) {
504 throw std::invalid_argument(
505 "Null time_ptr detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
506 }
507
508 if (batch_size == 0u) {
509 throw std::invalid_argument(
510 "Zero batch size detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
511 }
512
513 if (n_uvars == 0u) {
514 throw std::invalid_argument(
515 "Zero number of u variables detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
516 }
517
518 auto retval
519 = ptr()->taylor_diff_dbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size, high_accuracy);
520
521 if (retval == nullptr) {
522 throw std::invalid_argument(
523 "Null return value detected in func::taylor_diff_dbl() for the function '{}'"_format(get_name()));
524 }
525
526 return retval;
527 }
528
taylor_diff_ldbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy) const529 llvm::Value *func::taylor_diff_ldbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
530 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
531 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
532 std::uint32_t batch_size, bool high_accuracy) const
533 {
534 if (par_ptr == nullptr) {
535 throw std::invalid_argument(
536 "Null par_ptr detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
537 }
538
539 if (time_ptr == nullptr) {
540 throw std::invalid_argument(
541 "Null time_ptr detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
542 }
543
544 if (batch_size == 0u) {
545 throw std::invalid_argument(
546 "Zero batch size detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
547 }
548
549 if (n_uvars == 0u) {
550 throw std::invalid_argument(
551 "Zero number of u variables detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
552 }
553
554 auto retval
555 = ptr()->taylor_diff_ldbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size, high_accuracy);
556
557 if (retval == nullptr) {
558 throw std::invalid_argument(
559 "Null return value detected in func::taylor_diff_ldbl() for the function '{}'"_format(get_name()));
560 }
561
562 return retval;
563 }
564
565 #if defined(HEYOKA_HAVE_REAL128)
566
taylor_diff_f128(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy) const567 llvm::Value *func::taylor_diff_f128(llvm_state &s, const std::vector<std::uint32_t> &deps,
568 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
569 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
570 std::uint32_t batch_size, bool high_accuracy) const
571 {
572 if (par_ptr == nullptr) {
573 throw std::invalid_argument(
574 "Null par_ptr detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
575 }
576
577 if (time_ptr == nullptr) {
578 throw std::invalid_argument(
579 "Null time_ptr detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
580 }
581
582 if (batch_size == 0u) {
583 throw std::invalid_argument(
584 "Zero batch size detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
585 }
586
587 if (n_uvars == 0u) {
588 throw std::invalid_argument(
589 "Zero number of u variables detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
590 }
591
592 auto retval
593 = ptr()->taylor_diff_f128(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size, high_accuracy);
594
595 if (retval == nullptr) {
596 throw std::invalid_argument(
597 "Null return value detected in func::taylor_diff_f128() for the function '{}'"_format(get_name()));
598 }
599
600 return retval;
601 }
602
603 #endif
604
taylor_c_diff_func_dbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy) const605 llvm::Function *func::taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
606 bool high_accuracy) const
607 {
608 if (batch_size == 0u) {
609 throw std::invalid_argument(
610 "Zero batch size detected in func::taylor_c_diff_func_dbl() for the function '{}'"_format(get_name()));
611 }
612
613 if (n_uvars == 0u) {
614 throw std::invalid_argument(
615 "Zero number of u variables detected in func::taylor_c_diff_func_dbl() for the function '{}'"_format(
616 get_name()));
617 }
618
619 auto retval = ptr()->taylor_c_diff_func_dbl(s, n_uvars, batch_size, high_accuracy);
620
621 if (retval == nullptr) {
622 throw std::invalid_argument(
623 "Null return value detected in func::taylor_c_diff_func_dbl() for the function '{}'"_format(get_name()));
624 }
625
626 return retval;
627 }
628
taylor_c_diff_func_ldbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy) const629 llvm::Function *func::taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
630 bool high_accuracy) const
631 {
632 if (batch_size == 0u) {
633 throw std::invalid_argument(
634 "Zero batch size detected in func::taylor_c_diff_func_ldbl() for the function '{}'"_format(get_name()));
635 }
636
637 if (n_uvars == 0u) {
638 throw std::invalid_argument(
639 "Zero number of u variables detected in func::taylor_c_diff_func_ldbl() for the function '{}'"_format(
640 get_name()));
641 }
642
643 auto retval = ptr()->taylor_c_diff_func_ldbl(s, n_uvars, batch_size, high_accuracy);
644
645 if (retval == nullptr) {
646 throw std::invalid_argument(
647 "Null return value detected in func::taylor_c_diff_func_ldbl() for the function '{}'"_format(get_name()));
648 }
649
650 return retval;
651 }
652
653 #if defined(HEYOKA_HAVE_REAL128)
654
taylor_c_diff_func_f128(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy) const655 llvm::Function *func::taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
656 bool high_accuracy) const
657 {
658 if (batch_size == 0u) {
659 throw std::invalid_argument(
660 "Zero batch size detected in func::taylor_c_diff_func_f128() for the function '{}'"_format(get_name()));
661 }
662
663 if (n_uvars == 0u) {
664 throw std::invalid_argument(
665 "Zero number of u variables detected in func::taylor_c_diff_func_f128() for the function '{}'"_format(
666 get_name()));
667 }
668
669 auto retval = ptr()->taylor_c_diff_func_f128(s, n_uvars, batch_size, high_accuracy);
670
671 if (retval == nullptr) {
672 throw std::invalid_argument(
673 "Null return value detected in func::taylor_c_diff_func_f128() for the function '{}'"_format(get_name()));
674 }
675
676 return retval;
677 }
678
679 #endif
680
// Swap two func objects by exchanging their m_ptr members.
// Declared noexcept.
void swap(func &a, func &b) noexcept
{
    std::swap(a.m_ptr, b.m_ptr);
}
685
operator <<(std::ostream & os,const func & f)686 std::ostream &operator<<(std::ostream &os, const func &f)
687 {
688 f.ptr()->to_stream(os);
689
690 return os;
691 }
692
hash(const func & f)693 std::size_t hash(const func &f)
694 {
695 // NOTE: the initial hash value is computed by combining the hash values of:
696 // - the function name,
697 // - the function inner type index,
698 // - the arguments' hashes.
699 std::size_t seed = std::hash<std::string>{}(f.get_name());
700
701 boost::hash_combine(seed, f.get_type_index());
702
703 for (const auto &arg : f.args()) {
704 boost::hash_combine(seed, hash(arg));
705 }
706
707 // Combine with the extra hash value too.
708 boost::hash_combine(seed, f.ptr()->extra_hash());
709
710 return seed;
711 }
712
operator ==(const func & a,const func & b)713 bool operator==(const func &a, const func &b)
714 {
715 // Check if the underlying object is the same.
716 if (a.m_ptr == b.m_ptr) {
717 return true;
718 }
719
720 // NOTE: the initial comparison considers:
721 // - the function name,
722 // - the function inner type index,
723 // - the arguments.
724 // If they are all equal, the extra equality comparison logic
725 // is also run.
726 if (a.get_name() == b.get_name() && a.get_type_index() == b.get_type_index() && a.args() == b.args()) {
727 return a.ptr()->extra_equal_to(b);
728 } else {
729 return false;
730 }
731 }
732
operator !=(const func & a,const func & b)733 bool operator!=(const func &a, const func &b)
734 {
735 return !(a == b);
736 }
737
eval_dbl(const func & f,const std::unordered_map<std::string,double> & map,const std::vector<double> & pars)738 double eval_dbl(const func &f, const std::unordered_map<std::string, double> &map, const std::vector<double> &pars)
739 {
740 return f.eval_dbl(map, pars);
741 }
742
eval_ldbl(const func & f,const std::unordered_map<std::string,long double> & map,const std::vector<long double> & pars)743 long double eval_ldbl(const func &f, const std::unordered_map<std::string, long double> &map,
744 const std::vector<long double> &pars)
745 {
746 return f.eval_ldbl(map, pars);
747 }
748
749 #if defined(HEYOKA_HAVE_REAL128)
750
eval_f128(const func & f,const std::unordered_map<std::string,mppp::real128> & map,const std::vector<mppp::real128> & pars)751 mppp::real128 eval_f128(const func &f, const std::unordered_map<std::string, mppp::real128> &map,
752 const std::vector<mppp::real128> &pars)
753 {
754 return f.eval_f128(map, pars);
755 }
756
757 #endif
758
// Free-function counterpart of func::eval_batch_dbl(): out_values,
// map and pars are forwarded unchanged to the member function of f.
void eval_batch_dbl(std::vector<double> &out_values, const func &f,
                    const std::unordered_map<std::string, std::vector<double>> &map, const std::vector<double> &pars)
{
    f.eval_batch_dbl(out_values, map, pars);
}
764
eval_num_dbl(const func & f,const std::vector<double> & in)765 double eval_num_dbl(const func &f, const std::vector<double> &in)
766 {
767 return f.eval_num_dbl(in);
768 }
769
deval_num_dbl(const func & f,const std::vector<double> & in,std::vector<double>::size_type d)770 double deval_num_dbl(const func &f, const std::vector<double> &in, std::vector<double>::size_type d)
771 {
772 return f.deval_num_dbl(in, d);
773 }
774
update_node_values_dbl(std::vector<double> & node_values,const func & f,const std::unordered_map<std::string,double> & map,const std::vector<std::vector<std::size_t>> & node_connections,std::size_t & node_counter)775 void update_node_values_dbl(std::vector<double> &node_values, const func &f,
776 const std::unordered_map<std::string, double> &map,
777 const std::vector<std::vector<std::size_t>> &node_connections, std::size_t &node_counter)
778 {
779 const auto node_id = node_counter;
780 node_counter++;
781 // We have to recurse first as to make sure node_values is filled before being accessed later.
782 for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
783 update_node_values_dbl(node_values, f.args()[i], map, node_connections, node_counter);
784 }
785 // Then we compute
786 std::vector<double> in_values(f.args().size());
787 for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
788 in_values[i] = node_values[node_connections[node_id][i]];
789 }
790 node_values[node_id] = eval_num_dbl(f, in_values);
791 }
792
update_grad_dbl(std::unordered_map<std::string,double> & grad,const func & f,const std::unordered_map<std::string,double> & map,const std::vector<double> & node_values,const std::vector<std::vector<std::size_t>> & node_connections,std::size_t & node_counter,double acc)793 void update_grad_dbl(std::unordered_map<std::string, double> &grad, const func &f,
794 const std::unordered_map<std::string, double> &map, const std::vector<double> &node_values,
795 const std::vector<std::vector<std::size_t>> &node_connections, std::size_t &node_counter,
796 double acc)
797 {
798 const auto node_id = node_counter;
799 node_counter++;
800 std::vector<double> in_values(f.args().size());
801 for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
802 in_values[i] = node_values[node_connections[node_id][i]];
803 }
804 for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
805 auto value = deval_num_dbl(f, in_values, i);
806 update_grad_dbl(grad, f.args()[i], map, node_values, node_connections, node_counter, acc * value);
807 }
808 }
809
update_connections(std::vector<std::vector<std::size_t>> & node_connections,const func & f,std::size_t & node_counter)810 void update_connections(std::vector<std::vector<std::size_t>> &node_connections, const func &f,
811 std::size_t &node_counter)
812 {
813 const auto node_id = node_counter;
814 node_counter++;
815 node_connections.push_back(std::vector<std::size_t>(f.args().size()));
816 for (decltype(f.args().size()) i = 0u; i < f.args().size(); ++i) {
817 node_connections[node_id][i] = node_counter;
818 update_connections(node_connections, f.args()[i], node_counter);
819 };
820 }
821
822 } // namespace heyoka
823