1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <heyoka/config.hpp>
10 
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <set>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>
25 
26 #include <tbb/blocked_range.h>
27 #include <tbb/parallel_for.h>
28 
29 #include <llvm/IR/Function.h>
30 #include <llvm/IR/Value.h>
31 
32 #include <fmt/format.h>
33 #include <fmt/ostream.h>
34 
35 #if defined(HEYOKA_HAVE_REAL128)
36 
37 #include <mp++/real128.hpp>
38 
39 #endif
40 
41 #include <heyoka/detail/llvm_fwd.hpp>
42 #include <heyoka/detail/type_traits.hpp>
43 #include <heyoka/exceptions.hpp>
44 #include <heyoka/expression.hpp>
45 #include <heyoka/func.hpp>
46 #include <heyoka/llvm_state.hpp>
47 #include <heyoka/math/binary_op.hpp>
48 #include <heyoka/math/neg.hpp>
49 #include <heyoka/math/square.hpp>
50 #include <heyoka/math/sum.hpp>
51 #include <heyoka/math/time.hpp>
52 #include <heyoka/math/tpoly.hpp>
53 #include <heyoka/number.hpp>
54 #include <heyoka/param.hpp>
55 #include <heyoka/variable.hpp>
56 
57 #if defined(_MSC_VER) && !defined(__clang__)
58 
59 // NOTE: MSVC has issues with the other "using"
60 // statement form.
61 using namespace fmt::literals;
62 
63 #else
64 
65 using fmt::literals::operator""_format;
66 
67 #endif
68 
69 namespace heyoka
70 {
71 
expression()72 expression::expression() : expression(number{0.}) {}
73 
expression(double x)74 expression::expression(double x) : expression(number{x}) {}
75 
expression(long double x)76 expression::expression(long double x) : expression(number{x}) {}
77 
78 #if defined(HEYOKA_HAVE_REAL128)
79 
expression(mppp::real128 x)80 expression::expression(mppp::real128 x) : expression(number{x}) {}
81 
82 #endif
83 
expression(std::string s)84 expression::expression(std::string s) : expression(variable{std::move(s)}) {}
85 
expression(number n)86 expression::expression(number n) : m_value(std::move(n)) {}
87 
expression(variable var)88 expression::expression(variable var) : m_value(std::move(var)) {}
89 
expression(func f)90 expression::expression(func f) : m_value(std::move(f)) {}
91 
expression(param p)92 expression::expression(param p) : m_value(std::move(p)) {}
93 
94 expression::expression(const expression &) = default;
95 
96 expression::expression(expression &&) noexcept = default;
97 
98 expression::~expression() = default;
99 
100 expression &expression::operator=(const expression &) = default;
101 
102 expression &expression::operator=(expression &&) noexcept = default;
103 
value()104 expression::value_type &expression::value()
105 {
106     return m_value;
107 }
108 
value() const109 const expression::value_type &expression::value() const
110 {
111     return m_value;
112 }
113 
114 namespace detail
115 {
116 
117 namespace
118 {
119 
copy(std::unordered_map<const void *,expression> & func_map,const expression & e)120 expression copy(std::unordered_map<const void *, expression> &func_map, const expression &e)
121 {
122     return std::visit(
123         [&func_map](const auto &v) {
124             if constexpr (std::is_same_v<detail::uncvref_t<decltype(v)>, func>) {
125                 const auto f_id = v.get_ptr();
126 
127                 if (auto it = func_map.find(f_id); it != func_map.end()) {
128                     // We already copied the current function, fetch the copy
129                     // from the cache.
130                     return it->second;
131                 }
132 
133                 // Create a copy of v. Note that this will copy
134                 // the arguments of v via their copy constructor,
135                 // and thus any argument which is itself a function
136                 // will be shallow-copied.
137                 auto f_copy = v.copy();
138 
139                 // Perform a copy of the arguments of v which are functions.
140                 assert(v.args().size() == f_copy.args().size()); // LCOV_EXCL_LINE
141                 auto b1 = v.args().begin();
142                 for (auto [b2, e2] = f_copy.get_mutable_args_it(); b2 != e2; ++b1, ++b2) {
143                     // NOTE: the argument needs to be copied via a recursive
144                     // call to copy() only if it is a func. Otherwise, the copy
145                     // we made earlier via the copy constructor is already a deep copy.
146                     if (std::holds_alternative<func>(b1->value())) {
147                         *b2 = copy(func_map, *b1);
148                     }
149                 }
150 
151                 // Construct the return value and put it into the cache.
152                 auto ex = expression{std::move(f_copy)};
153                 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ex});
154                 // NOTE: an expression cannot contain itself.
155                 assert(flag); // LCOV_EXCL_LINE
156 
157                 return ex;
158             } else {
159                 return expression{v};
160             }
161         },
162         e.value());
163 }
164 
165 } // namespace
166 
167 } // namespace detail
168 
copy(const expression & e)169 expression copy(const expression &e)
170 {
171     std::unordered_map<const void *, expression> func_map;
172 
173     return detail::copy(func_map, e);
174 }
175 
176 inline namespace literals
177 {
178 
operator ""_dbl(long double x)179 expression operator""_dbl(long double x)
180 {
181     return expression{static_cast<double>(x)};
182 }
183 
operator ""_dbl(unsigned long long n)184 expression operator""_dbl(unsigned long long n)
185 {
186     return expression{static_cast<double>(n)};
187 }
188 
operator ""_ldbl(long double x)189 expression operator""_ldbl(long double x)
190 {
191     return expression{x};
192 }
193 
operator ""_ldbl(unsigned long long n)194 expression operator""_ldbl(unsigned long long n)
195 {
196     return expression{static_cast<long double>(n)};
197 }
198 
operator ""_var(const char * s,std::size_t n)199 expression operator""_var(const char *s, std::size_t n)
200 {
201     return expression{variable{std::string(s, n)}};
202 }
203 
204 } // namespace literals
205 
206 namespace detail
207 {
208 
prime_wrapper(std::string s)209 prime_wrapper::prime_wrapper(std::string s) : m_str(std::move(s)) {}
210 
211 prime_wrapper::prime_wrapper(const prime_wrapper &) = default;
212 
213 prime_wrapper::prime_wrapper(prime_wrapper &&) noexcept = default;
214 
215 prime_wrapper &prime_wrapper::operator=(const prime_wrapper &) = default;
216 
217 prime_wrapper &prime_wrapper::operator=(prime_wrapper &&) noexcept = default;
218 
219 prime_wrapper::~prime_wrapper() = default;
220 
operator =(expression e)221 std::pair<expression, expression> prime_wrapper::operator=(expression e) &&
222 {
223     return std::pair{expression{variable{std::move(m_str)}}, std::move(e)};
224 }
225 
226 } // namespace detail
227 
prime(expression e)228 detail::prime_wrapper prime(expression e)
229 {
230     return std::visit(
231         [&e](auto &v) -> detail::prime_wrapper {
232             if constexpr (std::is_same_v<variable, detail::uncvref_t<decltype(v)>>) {
233                 return detail::prime_wrapper{std::move(v.name())};
234             } else {
235                 throw std::invalid_argument(
236                     "Cannot apply the prime() operator to the non-variable expression '{}'"_format(e));
237             }
238         },
239         e.value());
240 }
241 
242 inline namespace literals
243 {
244 
operator ""_p(const char * s,std::size_t n)245 detail::prime_wrapper operator""_p(const char *s, std::size_t n)
246 {
247     return detail::prime_wrapper{std::string(s, n)};
248 }
249 
250 } // namespace literals
251 
252 namespace detail
253 {
254 
255 namespace
256 {
257 
get_variables(std::unordered_set<const void * > & func_set,std::set<std::string> & s_set,const expression & e)258 void get_variables(std::unordered_set<const void *> &func_set, std::set<std::string> &s_set, const expression &e)
259 {
260     std::visit(
261         [&func_set, &s_set](const auto &arg) {
262             using type = detail::uncvref_t<decltype(arg)>;
263 
264             if constexpr (std::is_same_v<type, func>) {
265                 const auto f_id = arg.get_ptr();
266 
267                 if (func_set.find(f_id) != func_set.end()) {
268                     // We already determined the list of variables for the
269                     // current function, exit.
270                     return;
271                 }
272 
273                 // Determine the list of variables for each
274                 // function argument.
275                 for (const auto &farg : arg.args()) {
276                     get_variables(func_set, s_set, farg);
277                 }
278 
279                 // Add the id of f to the set.
280                 [[maybe_unused]] const auto [_, flag] = func_set.insert(f_id);
281                 // NOTE: an expression cannot contain itself.
282                 assert(flag);
283             } else if constexpr (std::is_same_v<type, variable>) {
284                 s_set.insert(arg.name());
285             }
286         },
287         e.value());
288 }
289 
rename_variables(std::unordered_set<const void * > & func_set,expression & e,const std::unordered_map<std::string,std::string> & repl_map)290 void rename_variables(std::unordered_set<const void *> &func_set, expression &e,
291                       const std::unordered_map<std::string, std::string> &repl_map)
292 {
293     std::visit(
294         [&func_set, &repl_map](auto &arg) {
295             using type = detail::uncvref_t<decltype(arg)>;
296 
297             if constexpr (std::is_same_v<type, func>) {
298                 const auto f_id = arg.get_ptr();
299 
300                 if (func_set.find(f_id) != func_set.end()) {
301                     // We already renamed variables for the current function,
302                     // just return.
303                     return;
304                 }
305 
306                 for (auto [b, e] = arg.get_mutable_args_it(); b != e; ++b) {
307                     rename_variables(func_set, *b, repl_map);
308                 }
309 
310                 // Add the id of f to the set.
311                 [[maybe_unused]] const auto [_, flag] = func_set.insert(f_id);
312                 // NOTE: an expression cannot contain itself.
313                 assert(flag);
314             } else if constexpr (std::is_same_v<type, variable>) {
315                 if (auto it = repl_map.find(arg.name()); it != repl_map.end()) {
316                     arg.name() = it->second;
317                 }
318             }
319         },
320         e.value());
321 }
322 
323 } // namespace
324 
325 } // namespace detail
326 
get_variables(const expression & e)327 std::vector<std::string> get_variables(const expression &e)
328 {
329     std::unordered_set<const void *> func_set;
330     std::set<std::string> s_set;
331 
332     detail::get_variables(func_set, s_set, e);
333 
334     return std::vector<std::string>(s_set.begin(), s_set.end());
335 }
336 
rename_variables(expression & e,const std::unordered_map<std::string,std::string> & repl_map)337 void rename_variables(expression &e, const std::unordered_map<std::string, std::string> &repl_map)
338 {
339     std::unordered_set<const void *> func_set;
340 
341     detail::rename_variables(func_set, e, repl_map);
342 }
343 
swap(expression & ex0,expression & ex1)344 void swap(expression &ex0, expression &ex1) noexcept
345 {
346     std::swap(ex0.value(), ex1.value());
347 }
348 
349 // NOTE: this implementation does not take advantage of potentially
350 // repeating subexpressions. This is not currently a problem because
351 // hashing is needed only in the CSE for the decomposition, which involves
352 // only trivial expressions. However, this would likely be needed by a to_sympy()
353 // implementation in heyoka.py which allows for a dictionary of custom
354 // substitutions to be provided by the user.
hash(const expression & ex)355 std::size_t hash(const expression &ex)
356 {
357     return std::visit([](const auto &v) { return hash(v); }, ex.value());
358 }
359 
operator <<(std::ostream & os,const expression & e)360 std::ostream &operator<<(std::ostream &os, const expression &e)
361 {
362     return std::visit([&os](const auto &arg) -> std::ostream & { return os << arg; }, e.value());
363 }
364 
operator +(expression e)365 expression operator+(expression e)
366 {
367     return e;
368 }
369 
operator -(expression e)370 expression operator-(expression e)
371 {
372     if (auto num_ptr = std::get_if<number>(&e.value())) {
373         // Simplify -number to its numerical value.
374         return expression{-std::move(*num_ptr)};
375     } else {
376         if (auto fptr = detail::is_neg(e)) {
377             // Simplify -(-x) to x.
378             assert(!fptr->args().empty()); // LCOV_EXCL_LINE
379             return fptr->args()[0];
380         } else {
381             return neg(std::move(e));
382         }
383     }
384 }
385 
operator +(expression e1,expression e2)386 expression operator+(expression e1, expression e2)
387 {
388     // Simplify x + neg(y) to x - y.
389     if (auto fptr = detail::is_neg(e2)) {
390         assert(!fptr->args().empty()); // LCOV_EXCL_LINE
391         return std::move(e1) - fptr->args()[0];
392     }
393 
394     auto visitor = [](auto &&v1, auto &&v2) {
395         using type1 = detail::uncvref_t<decltype(v1)>;
396         using type2 = detail::uncvref_t<decltype(v2)>;
397 
398         if constexpr (std::is_same_v<type1, number> && std::is_same_v<type2, number>) {
399             // Both are numbers, add them and return the result.
400             return expression{std::forward<decltype(v1)>(v1) + std::forward<decltype(v2)>(v2)};
401         } else if constexpr (std::is_same_v<type1, number>) {
402             // e1 number, e2 symbolic.
403             if (is_zero(v1)) {
404                 // 0 + e2 = e2.
405                 return expression{std::forward<decltype(v2)>(v2)};
406             }
407             if constexpr (std::is_same_v<func, type2>) {
408                 if (auto pbop = v2.template extract<detail::binary_op>();
409                     pbop != nullptr && pbop->op() == detail::binary_op::type::add
410                     && std::holds_alternative<number>(pbop->args()[0].value())) {
411                     // e2 = a + x, where a is a number. Simplify e1 + (a + x) -> c + x, where c = e1 + a.
412                     return expression{std::forward<decltype(v1)>(v1)} + pbop->args()[0] + pbop->args()[1];
413                 }
414             }
415 
416             // NOTE: fall through the standard case.
417         } else if constexpr (std::is_same_v<type2, number>) {
418             // e1 symbolic, e2 number. Swap the operands so that the number comes first.
419             return expression{std::forward<decltype(v2)>(v2)} + expression{std::forward<decltype(v1)>(v1)};
420         }
421 
422         // The standard case.
423         return add(expression{std::forward<decltype(v1)>(v1)}, expression{std::forward<decltype(v2)>(v2)});
424     };
425 
426     return std::visit(visitor, std::move(e1.value()), std::move(e2.value()));
427 }
428 
operator -(expression e1,expression e2)429 expression operator-(expression e1, expression e2)
430 {
431     // Simplify x - (-y) to x + y.
432     if (auto fptr = detail::is_neg(e2)) {
433         assert(!fptr->args().empty()); // LCOV_EXCL_LINE
434         return std::move(e1) + fptr->args()[0];
435     }
436 
437     auto visitor = [](auto &&v1, auto &&v2) {
438         using type1 = detail::uncvref_t<decltype(v1)>;
439         using type2 = detail::uncvref_t<decltype(v2)>;
440 
441         if constexpr (std::is_same_v<type1, number> && std::is_same_v<type2, number>) {
442             // Both are numbers, subtract them.
443             return expression{std::forward<decltype(v1)>(v1) - std::forward<decltype(v2)>(v2)};
444         } else if constexpr (std::is_same_v<type1, number>) {
445             // e1 number, e2 symbolic.
446             if (is_zero(v1)) {
447                 // 0 - e2 = -e2.
448                 return -expression{std::forward<decltype(v2)>(v2)};
449             }
450             // NOTE: fall through the standard case if e1 is not zero.
451         } else if constexpr (std::is_same_v<type2, number>) {
452             // e1 symbolic, e2 number. Turn e1 - e2 into e1 + (-e2),
453             // because addition provides more simplification capabilities.
454             return expression{std::forward<decltype(v1)>(v1)} + expression{-std::forward<decltype(v2)>(v2)};
455         }
456 
457         // The standard case.
458         return sub(expression{std::forward<decltype(v1)>(v1)}, expression{std::forward<decltype(v2)>(v2)});
459     };
460 
461     return std::visit(visitor, std::move(e1.value()), std::move(e2.value()));
462 }
463 
operator *(expression e1,expression e2)464 expression operator*(expression e1, expression e2)
465 {
466     auto fptr1 = detail::is_neg(e1);
467     auto fptr2 = detail::is_neg(e2);
468 
469     if (fptr1 != nullptr && fptr2 != nullptr) {
470         // Simplify (-x) * (-y) into x*y.
471         assert(!fptr1->args().empty()); // LCOV_EXCL_LINE
472         assert(!fptr2->args().empty()); // LCOV_EXCL_LINE
473         return fptr1->args()[0] * fptr2->args()[0];
474     }
475 
476     // Simplify x*x -> square(x) if x is not a number (otherwise,
477     // we will numerically compute the result below).
478     if (e1 == e2 && !std::holds_alternative<number>(e1.value())) {
479         return square(std::move(e1));
480     }
481 
482     auto visitor = [fptr2](auto &&v1, auto &&v2) {
483         using type1 = detail::uncvref_t<decltype(v1)>;
484         using type2 = detail::uncvref_t<decltype(v2)>;
485 
486         if constexpr (std::is_same_v<type1, number> && std::is_same_v<type2, number>) {
487             // Both are numbers, multiply them.
488             return expression{std::forward<decltype(v1)>(v1) * std::forward<decltype(v2)>(v2)};
489         } else if constexpr (std::is_same_v<type1, number>) {
490             // e1 number, e2 symbolic.
491             if (is_zero(v1)) {
492                 // 0 * e2 = 0.
493                 return expression{number{0.}};
494             }
495             if (is_one(v1)) {
496                 // 1 * e2 = e2.
497                 return expression{std::forward<decltype(v2)>(v2)};
498             }
499             if (is_negative_one(v1)) {
500                 // -1 * e2 = -e2.
501                 return -expression{std::forward<decltype(v2)>(v2)};
502             }
503             if (fptr2 != nullptr) {
504                 // a * (-x) = (-a) * x.
505                 assert(!fptr2->args().empty()); // LCOV_EXCL_LINE
506                 return expression{-std::forward<decltype(v1)>(v1)} * fptr2->args()[0];
507             }
508             if constexpr (std::is_same_v<func, type2>) {
509                 if (auto pbop = v2.template extract<detail::binary_op>();
510                     pbop != nullptr && pbop->op() == detail::binary_op::type::mul
511                     && std::holds_alternative<number>(pbop->args()[0].value())) {
512                     // e2 = a * x, where a is a number. Simplify e1 * (a * x) -> c * x, where c = e1 * a.
513                     return expression{std::forward<decltype(v1)>(v1)} * pbop->args()[0] * pbop->args()[1];
514                 }
515             }
516 
517             // NOTE: fall through the standard case.
518         } else if constexpr (std::is_same_v<type2, number>) {
519             // e1 symbolic, e2 number. Swap the operands so that the number comes first.
520             return expression{std::forward<decltype(v2)>(v2)} * expression{std::forward<decltype(v1)>(v1)};
521         }
522 
523         // The standard case.
524         return mul(expression{std::forward<decltype(v1)>(v1)}, expression{std::forward<decltype(v2)>(v2)});
525     };
526 
527     return std::visit(visitor, std::move(e1.value()), std::move(e2.value()));
528 }
529 
operator /(expression e1,expression e2)530 expression operator/(expression e1, expression e2)
531 {
532     auto fptr1 = detail::is_neg(e1);
533     auto fptr2 = detail::is_neg(e2);
534 
535     if (fptr1 != nullptr && fptr2 != nullptr) {
536         // Simplify (-x) / (-y) into x/y.
537         assert(!fptr1->args().empty()); // LCOV_EXCL_LINE
538         assert(!fptr2->args().empty()); // LCOV_EXCL_LINE
539         return fptr1->args()[0] / fptr2->args()[0];
540     }
541 
542     auto visitor = [fptr1, fptr2](auto &&v1, auto &&v2) {
543         using type1 = detail::uncvref_t<decltype(v1)>;
544         using type2 = detail::uncvref_t<decltype(v2)>;
545 
546         if constexpr (std::is_same_v<type2, number>) {
547             // If the divisor is zero, always raise an error.
548             if (is_zero(v2)) {
549                 throw zero_division_error("Division by zero");
550             }
551         }
552 
553         if constexpr (std::is_same_v<type1, number> && std::is_same_v<type2, number>) {
554             // Both are numbers, divide them.
555             return expression{std::forward<decltype(v1)>(v1) / std::forward<decltype(v2)>(v2)};
556         } else if constexpr (std::is_same_v<type2, number>) {
557             // e1 is symbolic, e2 a number.
558             if (is_one(v2)) {
559                 // e1 / 1 = e1.
560                 return expression{std::forward<decltype(v1)>(v1)};
561             }
562             if (is_negative_one(v2)) {
563                 // e1 / -1 = -e1.
564                 return -expression{std::forward<decltype(v1)>(v1)};
565             }
566             if (fptr1 != nullptr) {
567                 // (-e1) / a = e1 / (-a).
568                 assert(!fptr1->args().empty()); // LCOV_EXCL_LINE
569                 return fptr1->args()[0] / expression{-std::forward<decltype(v2)>(v2)};
570             }
571             if constexpr (std::is_same_v<func, type1>) {
572                 if (auto pbop = v1.template extract<detail::binary_op>();
573                     pbop != nullptr && pbop->op() == detail::binary_op::type::div
574                     && std::holds_alternative<number>(pbop->args()[1].value())) {
575                     // e1 = x / a, where a is a number. Simplify (x / a) / b -> x / (a * b).
576                     return pbop->args()[0] / (pbop->args()[1] * expression{std::forward<decltype(v2)>(v2)});
577                 }
578             }
579 
580             // NOTE: fall through to the standard case.
581         } else if constexpr (std::is_same_v<type1, number>) {
582             // e1 is a number, e2 is symbolic.
583             if (is_zero(v1)) {
584                 // 0 / e2 == 0.
585                 return expression{number{0.}};
586             }
587             if (fptr2 != nullptr) {
588                 // a / (-e2) = (-a) / e2.
589                 assert(!fptr2->args().empty()); // LCOV_EXCL_LINE
590                 return expression{-std::forward<decltype(v1)>(v1)} / fptr2->args()[0];
591             }
592 
593             // NOTE: fall through to the standard case.
594         }
595 
596         // The standard case.
597         return div(expression{std::forward<decltype(v1)>(v1)}, expression{std::forward<decltype(v2)>(v2)});
598     };
599 
600     return std::visit(visitor, std::move(e1.value()), std::move(e2.value()));
601 }
602 
operator +(expression ex,double x)603 expression operator+(expression ex, double x)
604 {
605     return std::move(ex) + expression{x};
606 }
607 
operator +(expression ex,long double x)608 expression operator+(expression ex, long double x)
609 {
610     return std::move(ex) + expression{x};
611 }
612 
613 #if defined(HEYOKA_HAVE_REAL128)
614 
operator +(expression ex,mppp::real128 x)615 expression operator+(expression ex, mppp::real128 x)
616 {
617     return std::move(ex) + expression{x};
618 }
619 
620 #endif
621 
operator +(double x,expression ex)622 expression operator+(double x, expression ex)
623 {
624     return expression{x} + std::move(ex);
625 }
626 
operator +(long double x,expression ex)627 expression operator+(long double x, expression ex)
628 {
629     return expression{x} + std::move(ex);
630 }
631 
632 #if defined(HEYOKA_HAVE_REAL128)
633 
operator +(mppp::real128 x,expression ex)634 expression operator+(mppp::real128 x, expression ex)
635 {
636     return expression{x} + std::move(ex);
637 }
638 
639 #endif
640 
operator -(expression ex,double x)641 expression operator-(expression ex, double x)
642 {
643     return std::move(ex) - expression{x};
644 }
645 
operator -(expression ex,long double x)646 expression operator-(expression ex, long double x)
647 {
648     return std::move(ex) - expression{x};
649 }
650 
651 #if defined(HEYOKA_HAVE_REAL128)
652 
operator -(expression ex,mppp::real128 x)653 expression operator-(expression ex, mppp::real128 x)
654 {
655     return std::move(ex) - expression{x};
656 }
657 
658 #endif
659 
operator -(double x,expression ex)660 expression operator-(double x, expression ex)
661 {
662     return expression{x} - std::move(ex);
663 }
664 
operator -(long double x,expression ex)665 expression operator-(long double x, expression ex)
666 {
667     return expression{x} - std::move(ex);
668 }
669 
670 #if defined(HEYOKA_HAVE_REAL128)
671 
operator -(mppp::real128 x,expression ex)672 expression operator-(mppp::real128 x, expression ex)
673 {
674     return expression{x} - std::move(ex);
675 }
676 
677 #endif
678 
operator *(expression ex,double x)679 expression operator*(expression ex, double x)
680 {
681     return std::move(ex) * expression{x};
682 }
683 
operator *(expression ex,long double x)684 expression operator*(expression ex, long double x)
685 {
686     return std::move(ex) * expression{x};
687 }
688 
689 #if defined(HEYOKA_HAVE_REAL128)
690 
operator *(expression ex,mppp::real128 x)691 expression operator*(expression ex, mppp::real128 x)
692 {
693     return std::move(ex) * expression{x};
694 }
695 
696 #endif
697 
operator *(double x,expression ex)698 expression operator*(double x, expression ex)
699 {
700     return expression{x} * std::move(ex);
701 }
702 
operator *(long double x,expression ex)703 expression operator*(long double x, expression ex)
704 {
705     return expression{x} * std::move(ex);
706 }
707 
708 #if defined(HEYOKA_HAVE_REAL128)
709 
operator *(mppp::real128 x,expression ex)710 expression operator*(mppp::real128 x, expression ex)
711 {
712     return expression{x} * std::move(ex);
713 }
714 
715 #endif
716 
operator /(expression ex,double x)717 expression operator/(expression ex, double x)
718 {
719     return std::move(ex) / expression{x};
720 }
721 
operator /(expression ex,long double x)722 expression operator/(expression ex, long double x)
723 {
724     return std::move(ex) / expression{x};
725 }
726 
727 #if defined(HEYOKA_HAVE_REAL128)
728 
operator /(expression ex,mppp::real128 x)729 expression operator/(expression ex, mppp::real128 x)
730 {
731     return std::move(ex) / expression{x};
732 }
733 
734 #endif
735 
operator /(double x,expression ex)736 expression operator/(double x, expression ex)
737 {
738     return expression{x} / std::move(ex);
739 }
740 
operator /(long double x,expression ex)741 expression operator/(long double x, expression ex)
742 {
743     return expression{x} / std::move(ex);
744 }
745 
746 #if defined(HEYOKA_HAVE_REAL128)
747 
operator /(mppp::real128 x,expression ex)748 expression operator/(mppp::real128 x, expression ex)
749 {
750     return expression{x} / std::move(ex);
751 }
752 
753 #endif
754 
operator +=(expression & x,expression e)755 expression &operator+=(expression &x, expression e)
756 {
757     return x = std::move(x) + std::move(e);
758 }
759 
operator -=(expression & x,expression e)760 expression &operator-=(expression &x, expression e)
761 {
762     return x = std::move(x) - std::move(e);
763 }
764 
operator *=(expression & x,expression e)765 expression &operator*=(expression &x, expression e)
766 {
767     return x = std::move(x) * std::move(e);
768 }
769 
operator /=(expression & x,expression e)770 expression &operator/=(expression &x, expression e)
771 {
772     return x = std::move(x) / std::move(e);
773 }
774 
operator +=(expression & ex,double x)775 expression &operator+=(expression &ex, double x)
776 {
777     return ex += expression{x};
778 }
779 
operator +=(expression & ex,long double x)780 expression &operator+=(expression &ex, long double x)
781 {
782     return ex += expression{x};
783 }
784 
785 #if defined(HEYOKA_HAVE_REAL128)
786 
operator +=(expression & ex,mppp::real128 x)787 expression &operator+=(expression &ex, mppp::real128 x)
788 {
789     return ex += expression{x};
790 }
791 
792 #endif
793 
operator -=(expression & ex,double x)794 expression &operator-=(expression &ex, double x)
795 {
796     return ex -= expression{x};
797 }
798 
operator -=(expression & ex,long double x)799 expression &operator-=(expression &ex, long double x)
800 {
801     return ex -= expression{x};
802 }
803 
804 #if defined(HEYOKA_HAVE_REAL128)
805 
operator -=(expression & ex,mppp::real128 x)806 expression &operator-=(expression &ex, mppp::real128 x)
807 {
808     return ex -= expression{x};
809 }
810 
811 #endif
812 
operator *=(expression & ex,double x)813 expression &operator*=(expression &ex, double x)
814 {
815     return ex *= expression{x};
816 }
817 
operator *=(expression & ex,long double x)818 expression &operator*=(expression &ex, long double x)
819 {
820     return ex *= expression{x};
821 }
822 
823 #if defined(HEYOKA_HAVE_REAL128)
824 
operator *=(expression & ex,mppp::real128 x)825 expression &operator*=(expression &ex, mppp::real128 x)
826 {
827     return ex *= expression{x};
828 }
829 
830 #endif
831 
operator /=(expression & ex,double x)832 expression &operator/=(expression &ex, double x)
833 {
834     return ex /= expression{x};
835 }
836 
operator /=(expression & ex,long double x)837 expression &operator/=(expression &ex, long double x)
838 {
839     return ex /= expression{x};
840 }
841 
842 #if defined(HEYOKA_HAVE_REAL128)
843 
operator /=(expression & ex,mppp::real128 x)844 expression &operator/=(expression &ex, mppp::real128 x)
845 {
846     return ex /= expression{x};
847 }
848 
849 #endif
850 
operator ==(const expression & e1,const expression & e2)851 bool operator==(const expression &e1, const expression &e2)
852 {
853     auto visitor = [](const auto &v1, const auto &v2) {
854         using type1 = detail::uncvref_t<decltype(v1)>;
855         using type2 = detail::uncvref_t<decltype(v2)>;
856 
857         if constexpr (std::is_same_v<type1, type2>) {
858             return v1 == v2;
859         } else {
860             return false;
861         }
862     };
863 
864     return std::visit(visitor, e1.value(), e2.value());
865 }
866 
operator !=(const expression & e1,const expression & e2)867 bool operator!=(const expression &e1, const expression &e2)
868 {
869     return !(e1 == e2);
870 }
871 
872 namespace detail
873 {
874 
875 namespace
876 {
877 
get_n_nodes(std::unordered_map<const void *,std::size_t> & func_map,const expression & e)878 std::size_t get_n_nodes(std::unordered_map<const void *, std::size_t> &func_map, const expression &e)
879 {
880     return std::visit(
881         [&func_map](const auto &arg) -> std::size_t {
882             if constexpr (std::is_same_v<func, detail::uncvref_t<decltype(arg)>>) {
883                 const auto f_id = arg.get_ptr();
884 
885                 if (auto it = func_map.find(f_id); it != func_map.end()) {
886                     // We already computed the number of nodes for the current
887                     // function, return it.
888                     return it->second;
889                 }
890 
891                 std::size_t retval = 1;
892                 for (const auto &ex : arg.args()) {
893                     retval += get_n_nodes(func_map, ex);
894                 }
895 
896                 // Store the number of nodes for the current function
897                 // in the cache.
898                 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, retval});
899                 // NOTE: an expression cannot contain itself.
900                 assert(flag);
901 
902                 return retval;
903             } else {
904                 return 1;
905             }
906         },
907         e.value());
908 }
909 
910 } // namespace
911 
912 } // namespace detail
913 
get_n_nodes(const expression & e)914 std::size_t get_n_nodes(const expression &e)
915 {
916     std::unordered_map<const void *, std::size_t> func_map;
917 
918     return detail::get_n_nodes(func_map, e);
919 }
920 
921 namespace detail
922 {
923 
diff(std::unordered_map<const void *,expression> & func_map,const expression & e,const std::string & s)924 expression diff(std::unordered_map<const void *, expression> &func_map, const expression &e, const std::string &s)
925 {
926     return std::visit(
927         [&func_map, &s](const auto &arg) {
928             using type = detail::uncvref_t<decltype(arg)>;
929 
930             if constexpr (std::is_same_v<type, number>) {
931                 return std::visit([](const auto &v) { return expression{number{detail::uncvref_t<decltype(v)>(0)}}; },
932                                   arg.value());
933             } else if constexpr (std::is_same_v<type, param>) {
934                 // NOTE: if we ever implement single-precision support,
935                 // this should be probably changed into 0_flt (i.e., the lowest
936                 // precision numerical type), so that it does not trigger
937                 // type promotions in numerical constants. Other similar
938                 // occurrences as well (e.g., diff for variable).
939                 return 0_dbl;
940             } else if constexpr (std::is_same_v<type, variable>) {
941                 if (s == arg.name()) {
942                     return 1_dbl;
943                 } else {
944                     return 0_dbl;
945                 }
946             } else {
947                 const auto f_id = arg.get_ptr();
948 
949                 if (auto it = func_map.find(f_id); it != func_map.end()) {
950                     // We already performed diff on the current function,
951                     // fetch the result from the cache.
952                     return it->second;
953                 }
954 
955                 auto ret = arg.diff(func_map, s);
956 
957                 // Put the return value in the cache.
958                 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ret});
959                 // NOTE: an expression cannot contain itself.
960                 assert(flag);
961 
962                 return ret;
963             }
964         },
965         e.value());
966 }
967 
diff(std::unordered_map<const void *,expression> & func_map,const expression & e,const param & p)968 expression diff(std::unordered_map<const void *, expression> &func_map, const expression &e, const param &p)
969 {
970     return std::visit(
971         [&func_map, &p](const auto &arg) {
972             using type = detail::uncvref_t<decltype(arg)>;
973 
974             if constexpr (std::is_same_v<type, number>) {
975                 return std::visit([](const auto &v) { return expression{number{detail::uncvref_t<decltype(v)>(0)}}; },
976                                   arg.value());
977             } else if constexpr (std::is_same_v<type, param>) {
978                 if (p.idx() == arg.idx()) {
979                     return 1_dbl;
980                 } else {
981                     return 0_dbl;
982                 }
983             } else if constexpr (std::is_same_v<type, variable>) {
984                 return 0_dbl;
985             } else {
986                 const auto f_id = arg.get_ptr();
987 
988                 if (auto it = func_map.find(f_id); it != func_map.end()) {
989                     // We already performed diff on the current function,
990                     // fetch the result from the cache.
991                     return it->second;
992                 }
993 
994                 auto ret = arg.diff(func_map, p);
995 
996                 // Put the return value in the cache.
997                 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ret});
998                 // NOTE: an expression cannot contain itself.
999                 assert(flag);
1000 
1001                 return ret;
1002             }
1003         },
1004         e.value());
1005 }
1006 
1007 } // namespace detail
1008 
// Derivative of e with respect to the variable named s.
expression diff(const expression &e, const std::string &s)
{
    // Memoisation cache for function nodes, local to this call.
    std::unordered_map<const void *, expression> cache;
    return detail::diff(cache, e, s);
}
1015 
// Derivative of e with respect to the parameter p.
expression diff(const expression &e, const param &p)
{
    // Memoisation cache for function nodes, local to this call.
    std::unordered_map<const void *, expression> cache;
    return detail::diff(cache, e, p);
}
1022 
// Derivative of e with respect to x, which must hold either a variable or a
// param. Throws std::invalid_argument for any other kind of expression.
expression diff(const expression &e, const expression &x)
{
    if (const auto *var = std::get_if<variable>(&x.value())) {
        return diff(e, var->name());
    }

    if (const auto *par = std::get_if<param>(&x.value())) {
        return diff(e, *par);
    }

    throw std::invalid_argument(
        "Derivatives are currently supported only with respect to variables and parameters");
}
1038 
1039 namespace detail
1040 {
1041 
1042 namespace
1043 {
1044 
1045 // NOTE: an in-place API would perform better.
subs(std::unordered_map<const void *,expression> & func_map,const expression & ex,const std::unordered_map<std::string,expression> & smap)1046 expression subs(std::unordered_map<const void *, expression> &func_map, const expression &ex,
1047                 const std::unordered_map<std::string, expression> &smap)
1048 {
1049     return std::visit(
1050         [&func_map, &smap](const auto &arg) {
1051             using type = detail::uncvref_t<decltype(arg)>;
1052 
1053             if constexpr (std::is_same_v<type, number> || std::is_same_v<type, param>) {
1054                 return expression{arg};
1055             } else if constexpr (std::is_same_v<type, variable>) {
1056                 if (auto it = smap.find(arg.name()); it == smap.end()) {
1057                     return expression{arg};
1058                 } else {
1059                     return it->second;
1060                 }
1061             } else {
1062                 const auto f_id = arg.get_ptr();
1063 
1064                 if (auto it = func_map.find(f_id); it != func_map.end()) {
1065                     // We already performed substitution on the current function,
1066                     // fetch the result from the cache.
1067                     return it->second;
1068                 }
1069 
1070                 // NOTE: this creates a separate instance of arg, but its
1071                 // arguments are shallow-copied.
1072                 auto tmp = arg.copy();
1073 
1074                 for (auto [b, e] = tmp.get_mutable_args_it(); b != e; ++b) {
1075                     *b = subs(func_map, *b, smap);
1076                 }
1077 
1078                 // Put the return value in the cache.
1079                 auto ret = expression{std::move(tmp)};
1080                 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ret});
1081                 // NOTE: an expression cannot contain itself.
1082                 assert(flag);
1083 
1084                 return ret;
1085             }
1086         },
1087         ex.value());
1088 }
1089 
1090 } // namespace
1091 
1092 } // namespace detail
1093 
// Replace in e every variable whose name is a key of smap with the
// corresponding expression.
expression subs(const expression &e, const std::unordered_map<std::string, expression> &smap)
{
    // Memoisation cache for function nodes, local to this call.
    std::unordered_map<const void *, expression> cache;
    return detail::subs(cache, e, smap);
}
1100 
1101 namespace detail
1102 {
1103 
1104 namespace
1105 {
1106 
1107 // Pairwise reduction of a vector of expressions.
1108 template <typename F>
pairwise_reduce(const F & func,std::vector<expression> list)1109 expression pairwise_reduce(const F &func, std::vector<expression> list)
1110 {
1111     assert(!list.empty());
1112 
1113     // LCOV_EXCL_START
1114     if (list.size() == std::numeric_limits<decltype(list.size())>::max()) {
1115         throw std::overflow_error("Overflow detected in pairwise_reduce()");
1116     }
1117     // LCOV_EXCL_STOP
1118 
1119     while (list.size() != 1u) {
1120         const auto cur_size = list.size();
1121 
1122         // Init the new list. The size will be halved, +1 if the
1123         // current size is odd.
1124         const auto next_size = cur_size / 2u + cur_size % 2u;
1125         std::vector<expression> new_list(next_size);
1126 
1127         tbb::parallel_for(tbb::blocked_range<decltype(new_list.size())>(0, new_list.size()),
1128                           [&list, &new_list, cur_size, &func](const auto &r) {
1129                               for (auto i = r.begin(); i != r.end(); ++i) {
1130                                   if (i * 2u == cur_size - 1u) {
1131                                       // list has an odd size, and we are at the last element of list.
1132                                       // Just move it to new_list.
1133                                       new_list[i] = std::move(list.back());
1134                                   } else {
1135                                       new_list[i] = func(std::move(list[i * 2u]), std::move(list[i * 2u + 1u]));
1136                                   }
1137                               }
1138                           });
1139 
1140         new_list.swap(list);
1141     }
1142 
1143     return std::move(list[0]);
1144 }
1145 
1146 } // namespace
1147 
1148 } // namespace detail
1149 
1150 // Pairwise product.
pairwise_prod(std::vector<expression> prod)1151 expression pairwise_prod(std::vector<expression> prod)
1152 {
1153     if (prod.empty()) {
1154         return 1_dbl;
1155     }
1156 
1157     return detail::pairwise_reduce([](expression &&a, expression &&b) { return std::move(a) * std::move(b); },
1158                                    std::move(prod));
1159 }
1160 
eval_dbl(const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<double> & pars)1161 double eval_dbl(const expression &e, const std::unordered_map<std::string, double> &map,
1162                 const std::vector<double> &pars)
1163 {
1164     return std::visit([&](const auto &arg) { return eval_dbl(arg, map, pars); }, e.value());
1165 }
1166 
eval_ldbl(const expression & e,const std::unordered_map<std::string,long double> & map,const std::vector<long double> & pars)1167 long double eval_ldbl(const expression &e, const std::unordered_map<std::string, long double> &map,
1168                       const std::vector<long double> &pars)
1169 {
1170     return std::visit([&](const auto &arg) { return eval_ldbl(arg, map, pars); }, e.value());
1171 }
1172 
#if defined(HEYOKA_HAVE_REAL128)
// Evaluate e in quadruple precision by forwarding to the eval_f128() overload
// for the node type held by e. map presumably associates variable names with
// values and pars holds the parameter values (TODO confirm against overloads).
mppp::real128 eval_f128(const expression &e, const std::unordered_map<std::string, mppp::real128> &map,
                        const std::vector<mppp::real128> &pars)
{
    const auto visitor = [&map, &pars](const auto &node) { return eval_f128(node, map, pars); };
    return std::visit(visitor, e.value());
}
#endif
1180 
eval_batch_dbl(std::vector<double> & retval,const expression & e,const std::unordered_map<std::string,std::vector<double>> & map,const std::vector<double> & pars)1181 void eval_batch_dbl(std::vector<double> &retval, const expression &e,
1182                     const std::unordered_map<std::string, std::vector<double>> &map, const std::vector<double> &pars)
1183 {
1184     std::visit([&](const auto &arg) { eval_batch_dbl(retval, arg, map, pars); }, e.value());
1185 }
1186 
compute_connections(const expression & e)1187 std::vector<std::vector<std::size_t>> compute_connections(const expression &e)
1188 {
1189     std::vector<std::vector<std::size_t>> node_connections;
1190     std::size_t node_counter = 0u;
1191     update_connections(node_connections, e, node_counter);
1192     return node_connections;
1193 }
1194 
update_connections(std::vector<std::vector<std::size_t>> & node_connections,const expression & e,std::size_t & node_counter)1195 void update_connections(std::vector<std::vector<std::size_t>> &node_connections, const expression &e,
1196                         std::size_t &node_counter)
1197 {
1198     std::visit([&node_connections,
1199                 &node_counter](const auto &arg) { update_connections(node_connections, arg, node_counter); },
1200                e.value());
1201 }
1202 
compute_node_values_dbl(const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<std::vector<std::size_t>> & node_connections)1203 std::vector<double> compute_node_values_dbl(const expression &e, const std::unordered_map<std::string, double> &map,
1204                                             const std::vector<std::vector<std::size_t>> &node_connections)
1205 {
1206     std::vector<double> node_values(node_connections.size());
1207     std::size_t node_counter = 0u;
1208     update_node_values_dbl(node_values, e, map, node_connections, node_counter);
1209     return node_values;
1210 }
1211 
update_node_values_dbl(std::vector<double> & node_values,const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<std::vector<std::size_t>> & node_connections,std::size_t & node_counter)1212 void update_node_values_dbl(std::vector<double> &node_values, const expression &e,
1213                             const std::unordered_map<std::string, double> &map,
1214                             const std::vector<std::vector<std::size_t>> &node_connections, std::size_t &node_counter)
1215 {
1216     std::visit([&map, &node_values, &node_connections, &node_counter](
1217                    const auto &arg) { update_node_values_dbl(node_values, arg, map, node_connections, node_counter); },
1218                e.value());
1219 }
1220 
compute_grad_dbl(const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<std::vector<std::size_t>> & node_connections)1221 std::unordered_map<std::string, double> compute_grad_dbl(const expression &e,
1222                                                          const std::unordered_map<std::string, double> &map,
1223                                                          const std::vector<std::vector<std::size_t>> &node_connections)
1224 {
1225     std::unordered_map<std::string, double> grad;
1226     auto node_values = compute_node_values_dbl(e, map, node_connections);
1227     std::size_t node_counter = 0u;
1228     update_grad_dbl(grad, e, map, node_values, node_connections, node_counter);
1229     return grad;
1230 }
1231 
update_grad_dbl(std::unordered_map<std::string,double> & grad,const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<double> & node_values,const std::vector<std::vector<std::size_t>> & node_connections,std::size_t & node_counter,double acc)1232 void update_grad_dbl(std::unordered_map<std::string, double> &grad, const expression &e,
1233                      const std::unordered_map<std::string, double> &map, const std::vector<double> &node_values,
1234                      const std::vector<std::vector<std::size_t>> &node_connections, std::size_t &node_counter,
1235                      double acc)
1236 {
1237     std::visit(
1238         [&map, &grad, &node_values, &node_connections, &node_counter, &acc](const auto &arg) {
1239             update_grad_dbl(grad, arg, map, node_values, node_connections, node_counter, acc);
1240         },
1241         e.value());
1242 }
1243 
1244 namespace detail
1245 {
1246 
taylor_decompose(std::unordered_map<const void *,taylor_dc_t::size_type> & func_map,const expression & ex,taylor_dc_t & dc)1247 taylor_dc_t::size_type taylor_decompose(std::unordered_map<const void *, taylor_dc_t::size_type> &func_map,
1248                                         const expression &ex, taylor_dc_t &dc)
1249 {
1250     if (auto fptr = std::get_if<func>(&ex.value())) {
1251         return fptr->taylor_decompose(func_map, dc);
1252     } else {
1253         return 0;
1254     }
1255 }
1256 
1257 } // namespace detail
1258 
1259 // Decompose ex into dc. The return value is the index, in dc,
1260 // which corresponds to the decomposed version of ex.
1261 // If the return value is zero, ex was not decomposed.
taylor_decompose(const expression & ex,taylor_dc_t & dc)1262 taylor_dc_t::size_type taylor_decompose(const expression &ex, taylor_dc_t &dc)
1263 {
1264     std::unordered_map<const void *, taylor_dc_t::size_type> func_map;
1265 
1266     return detail::taylor_decompose(func_map, ex, dc);
1267 }
1268 
1269 namespace detail
1270 {
1271 
1272 namespace
1273 {
1274 
1275 template <typename T>
taylor_diff_impl(llvm_state & s,const expression & ex,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy)1276 llvm::Value *taylor_diff_impl(llvm_state &s, const expression &ex, const std::vector<std::uint32_t> &deps,
1277                               const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
1278                               std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size,
1279                               bool high_accuracy)
1280 {
1281     if (auto fptr = std::get_if<func>(&ex.value())) {
1282         if constexpr (std::is_same_v<T, double>) {
1283             return fptr->taylor_diff_dbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1284                                          high_accuracy);
1285         } else if constexpr (std::is_same_v<T, long double>) {
1286             return fptr->taylor_diff_ldbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1287                                           high_accuracy);
1288 #if defined(HEYOKA_HAVE_REAL128)
1289         } else if constexpr (std::is_same_v<T, mppp::real128>) {
1290             return fptr->taylor_diff_f128(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1291                                           high_accuracy);
1292 #endif
1293         } else {
1294             static_assert(detail::always_false_v<T>, "Unhandled type.");
1295         }
1296     } else {
1297         // LCOV_EXCL_START
1298         throw std::invalid_argument("Taylor derivatives can be computed only for functions");
1299         // LCOV_EXCL_STOP
1300     }
1301 }
1302 
1303 } // namespace
1304 
1305 } // namespace detail
1306 
taylor_diff_dbl(llvm_state & s,const expression & ex,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy)1307 llvm::Value *taylor_diff_dbl(llvm_state &s, const expression &ex, const std::vector<std::uint32_t> &deps,
1308                              const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
1309                              std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size,
1310                              bool high_accuracy)
1311 
1312 {
1313     return detail::taylor_diff_impl<double>(s, ex, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1314                                             high_accuracy);
1315 }
1316 
taylor_diff_ldbl(llvm_state & s,const expression & ex,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy)1317 llvm::Value *taylor_diff_ldbl(llvm_state &s, const expression &ex, const std::vector<std::uint32_t> &deps,
1318                               const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
1319                               std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size,
1320                               bool high_accuracy)
1321 {
1322     return detail::taylor_diff_impl<long double>(s, ex, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1323                                                  high_accuracy);
1324 }
1325 
#if defined(HEYOKA_HAVE_REAL128)

// Taylor derivative of the function ex in quadruple precision. All arguments
// are forwarded to detail::taylor_diff_impl<mppp::real128>(), which throws
// std::invalid_argument if ex is not a function.
llvm::Value *taylor_diff_f128(llvm_state &s, const expression &ex, const std::vector<std::uint32_t> &deps,
                              const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
                              std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size,
                              bool high_accuracy)
{
    return detail::taylor_diff_impl<mppp::real128>(s, ex, deps, arr, par_ptr, time_ptr, n_uvars, order, idx,
                                                   batch_size, high_accuracy);
}

#endif
1338 
1339 namespace detail
1340 {
1341 
1342 namespace
1343 {
1344 
1345 template <typename T>
taylor_c_diff_func_impl(llvm_state & s,const expression & ex,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy)1346 llvm::Function *taylor_c_diff_func_impl(llvm_state &s, const expression &ex, std::uint32_t n_uvars,
1347                                         std::uint32_t batch_size, bool high_accuracy)
1348 {
1349     if (auto fptr = std::get_if<func>(&ex.value())) {
1350         if constexpr (std::is_same_v<T, double>) {
1351             return fptr->taylor_c_diff_func_dbl(s, n_uvars, batch_size, high_accuracy);
1352         } else if constexpr (std::is_same_v<T, long double>) {
1353             return fptr->taylor_c_diff_func_ldbl(s, n_uvars, batch_size, high_accuracy);
1354 #if defined(HEYOKA_HAVE_REAL128)
1355         } else if constexpr (std::is_same_v<T, mppp::real128>) {
1356             return fptr->taylor_c_diff_func_f128(s, n_uvars, batch_size, high_accuracy);
1357 #endif
1358         } else {
1359             static_assert(detail::always_false_v<T>, "Unhandled type.");
1360         }
1361     } else {
1362         // LCOV_EXCL_START
1363         throw std::invalid_argument("Taylor derivatives in compact mode can be computed only for functions");
1364         // LCOV_EXCL_STOP
1365     }
1366 }
1367 
1368 } // namespace
1369 
1370 } // namespace detail
1371 
taylor_c_diff_func_dbl(llvm_state & s,const expression & ex,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy)1372 llvm::Function *taylor_c_diff_func_dbl(llvm_state &s, const expression &ex, std::uint32_t n_uvars,
1373                                        std::uint32_t batch_size, bool high_accuracy)
1374 {
1375     return detail::taylor_c_diff_func_impl<double>(s, ex, n_uvars, batch_size, high_accuracy);
1376 }
1377 
taylor_c_diff_func_ldbl(llvm_state & s,const expression & ex,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy)1378 llvm::Function *taylor_c_diff_func_ldbl(llvm_state &s, const expression &ex, std::uint32_t n_uvars,
1379                                         std::uint32_t batch_size, bool high_accuracy)
1380 {
1381     return detail::taylor_c_diff_func_impl<long double>(s, ex, n_uvars, batch_size, high_accuracy);
1382 }
1383 
#if defined(HEYOKA_HAVE_REAL128)

// Compact-mode Taylor derivative function for ex in quadruple precision;
// forwards to detail::taylor_c_diff_func_impl<mppp::real128>().
llvm::Function *taylor_c_diff_func_f128(llvm_state &s, const expression &ex, std::uint32_t n_uvars,
                                        std::uint32_t batch_size, bool high_accuracy)
{
    return detail::taylor_c_diff_func_impl<mppp::real128>(s, ex, n_uvars, batch_size, high_accuracy);
}

#endif
1393 
1394 namespace detail
1395 {
1396 
1397 // Helper to detect if ex is an integral number.
is_integral(const expression & ex)1398 bool is_integral(const expression &ex)
1399 {
1400     return std::visit(
1401         [](const auto &v) {
1402             using type = detail::uncvref_t<decltype(v)>;
1403 
1404             if constexpr (std::is_same_v<type, number>) {
1405                 return std::visit(
1406                     [](const auto &x) {
1407                         using std::trunc;
1408                         using std::isfinite;
1409 
1410                         return isfinite(x) && x == trunc(x);
1411                     },
1412                     v.value());
1413             } else {
1414                 // Not a number.
1415                 return false;
1416             }
1417         },
1418         ex.value());
1419 }
1420 
1421 // Helper to detect if ex is a number in the form n / 2,
1422 // where n is an odd integral value.
is_odd_integral_half(const expression & ex)1423 bool is_odd_integral_half(const expression &ex)
1424 {
1425     return std::visit(
1426         [](const auto &v) {
1427             using type = detail::uncvref_t<decltype(v)>;
1428 
1429             if constexpr (std::is_same_v<type, number>) {
1430                 return std::visit(
1431                     [](const auto &x) {
1432                         using std::trunc;
1433                         using std::isfinite;
1434 
1435                         if (!isfinite(x) || x == trunc(x)) {
1436                             // x is not finite, or it is already
1437                             // an integral value.
1438                             return false;
1439                         }
1440 
1441                         const auto y = 2 * x;
1442 
1443                         return isfinite(y) && y == trunc(y);
1444                     },
1445                     v.value());
1446             } else {
1447                 // Not a number.
1448                 return false;
1449             }
1450         },
1451         ex.value());
1452 }
1453 
operator [](std::uint32_t idx) const1454 expression par_impl::operator[](std::uint32_t idx) const
1455 {
1456     return expression{param{idx}};
1457 }
1458 
1459 } // namespace detail
1460 
1461 namespace detail
1462 {
1463 
1464 namespace
1465 {
1466 
get_param_size(std::unordered_set<const void * > & func_set,const expression & ex)1467 std::uint32_t get_param_size(std::unordered_set<const void *> &func_set, const expression &ex)
1468 {
1469     std::uint32_t retval = 0;
1470 
1471     std::visit(
1472         [&retval, &func_set](const auto &v) {
1473             using type = detail::uncvref_t<decltype(v)>;
1474 
1475             if constexpr (std::is_same_v<type, param>) {
1476                 if (v.idx() == std::numeric_limits<std::uint32_t>::max()) {
1477                     throw std::overflow_error("Overflow dected in get_n_param()");
1478                 }
1479 
1480                 retval = std::max(static_cast<std::uint32_t>(v.idx() + 1u), retval);
1481             } else if constexpr (std::is_same_v<type, func>) {
1482                 const auto f_id = v.get_ptr();
1483 
1484                 if (auto it = func_set.find(f_id); it != func_set.end()) {
1485                     // We already computed the number of params for the current
1486                     // function, exit.
1487                     return;
1488                 }
1489 
1490                 for (const auto &a : v.args()) {
1491                     retval = std::max(get_param_size(func_set, a), retval);
1492                 }
1493 
1494                 // Update the cache.
1495                 [[maybe_unused]] const auto [_, flag] = func_set.insert(f_id);
1496                 // NOTE: an expression cannot contain itself.
1497                 assert(flag);
1498             }
1499         },
1500         ex.value());
1501 
1502     return retval;
1503 }
1504 
1505 } // namespace
1506 
1507 } // namespace detail
1508 
1509 // Determine the size of the parameter vector from the highest
1510 // param index appearing in an expression. If the return value
1511 // is zero, no params appear in the expression.
get_param_size(const expression & ex)1512 std::uint32_t get_param_size(const expression &ex)
1513 {
1514     std::unordered_set<const void *> func_set;
1515 
1516     return detail::get_param_size(func_set, ex);
1517 }
1518 
1519 namespace detail
1520 {
1521 
1522 namespace
1523 {
1524 
has_time(std::unordered_set<const void * > & func_set,const expression & ex)1525 bool has_time(std::unordered_set<const void *> &func_set, const expression &ex)
1526 {
1527     // If the expression itself is a time function or a tpoly,
1528     // return true.
1529     if (detail::is_time(ex) || detail::is_tpoly(ex)) {
1530         return true;
1531     }
1532 
1533     // Otherwise:
1534     // - if ex is a function, check if any of its arguments
1535     //   is time-dependent,
1536     // - otherwise, return false.
1537     return std::visit(
1538         [&func_set](const auto &v) {
1539             using type = detail::uncvref_t<decltype(v)>;
1540 
1541             if constexpr (std::is_same_v<type, func>) {
1542                 const auto f_id = v.get_ptr();
1543 
1544                 if (auto it = func_set.find(f_id); it != func_set.end()) {
1545                     // We already determined if this function contains time,
1546                     // return false (if the function does contain time, the first
1547                     // time it was encountered we returned true and we could not
1548                     // possibly end up here).
1549                     return false;
1550                 }
1551 
1552                 // Update the cache.
1553                 // NOTE: do it earlier than usual in order to avoid having
1554                 // to repeat this code twice for the two paths below.
1555                 func_set.insert(f_id);
1556 
1557                 for (const auto &a : v.args()) {
1558                     if (has_time(func_set, a)) {
1559                         return true;
1560                     }
1561                 }
1562             }
1563 
1564             return false;
1565         },
1566         ex.value());
1567 }
1568 
1569 } // namespace
1570 
1571 } // namespace detail
1572 
1573 // Determine if an expression is time-dependent.
has_time(const expression & ex)1574 bool has_time(const expression &ex)
1575 {
1576     std::unordered_set<const void *> func_set;
1577 
1578     return detail::has_time(func_set, ex);
1579 }
1580 
1581 } // namespace heyoka
1582