1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 #include <heyoka/config.hpp>
10
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <set>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>
25
26 #include <tbb/blocked_range.h>
27 #include <tbb/parallel_for.h>
28
29 #include <llvm/IR/Function.h>
30 #include <llvm/IR/Value.h>
31
32 #include <fmt/format.h>
33 #include <fmt/ostream.h>
34
35 #if defined(HEYOKA_HAVE_REAL128)
36
37 #include <mp++/real128.hpp>
38
39 #endif
40
41 #include <heyoka/detail/llvm_fwd.hpp>
42 #include <heyoka/detail/type_traits.hpp>
43 #include <heyoka/exceptions.hpp>
44 #include <heyoka/expression.hpp>
45 #include <heyoka/func.hpp>
46 #include <heyoka/llvm_state.hpp>
47 #include <heyoka/math/binary_op.hpp>
48 #include <heyoka/math/neg.hpp>
49 #include <heyoka/math/square.hpp>
50 #include <heyoka/math/sum.hpp>
51 #include <heyoka/math/time.hpp>
52 #include <heyoka/math/tpoly.hpp>
53 #include <heyoka/number.hpp>
54 #include <heyoka/param.hpp>
55 #include <heyoka/variable.hpp>
56
57 #if defined(_MSC_VER) && !defined(__clang__)
58
59 // NOTE: MSVC has issues with the other "using"
60 // statement form.
61 using namespace fmt::literals;
62
63 #else
64
65 using fmt::literals::operator""_format;
66
67 #endif
68
69 namespace heyoka
70 {
71
expression()72 expression::expression() : expression(number{0.}) {}
73
expression(double x)74 expression::expression(double x) : expression(number{x}) {}
75
expression(long double x)76 expression::expression(long double x) : expression(number{x}) {}
77
78 #if defined(HEYOKA_HAVE_REAL128)
79
expression(mppp::real128 x)80 expression::expression(mppp::real128 x) : expression(number{x}) {}
81
82 #endif
83
expression(std::string s)84 expression::expression(std::string s) : expression(variable{std::move(s)}) {}
85
expression(number n)86 expression::expression(number n) : m_value(std::move(n)) {}
87
expression(variable var)88 expression::expression(variable var) : m_value(std::move(var)) {}
89
expression(func f)90 expression::expression(func f) : m_value(std::move(f)) {}
91
expression(param p)92 expression::expression(param p) : m_value(std::move(p)) {}
93
94 expression::expression(const expression &) = default;
95
96 expression::expression(expression &&) noexcept = default;
97
98 expression::~expression() = default;
99
100 expression &expression::operator=(const expression &) = default;
101
102 expression &expression::operator=(expression &&) noexcept = default;
103
value()104 expression::value_type &expression::value()
105 {
106 return m_value;
107 }
108
value() const109 const expression::value_type &expression::value() const
110 {
111 return m_value;
112 }
113
114 namespace detail
115 {
116
117 namespace
118 {
119
copy(std::unordered_map<const void *,expression> & func_map,const expression & e)120 expression copy(std::unordered_map<const void *, expression> &func_map, const expression &e)
121 {
122 return std::visit(
123 [&func_map](const auto &v) {
124 if constexpr (std::is_same_v<detail::uncvref_t<decltype(v)>, func>) {
125 const auto f_id = v.get_ptr();
126
127 if (auto it = func_map.find(f_id); it != func_map.end()) {
128 // We already copied the current function, fetch the copy
129 // from the cache.
130 return it->second;
131 }
132
133 // Create a copy of v. Note that this will copy
134 // the arguments of v via their copy constructor,
135 // and thus any argument which is itself a function
136 // will be shallow-copied.
137 auto f_copy = v.copy();
138
139 // Perform a copy of the arguments of v which are functions.
140 assert(v.args().size() == f_copy.args().size()); // LCOV_EXCL_LINE
141 auto b1 = v.args().begin();
142 for (auto [b2, e2] = f_copy.get_mutable_args_it(); b2 != e2; ++b1, ++b2) {
143 // NOTE: the argument needs to be copied via a recursive
144 // call to copy() only if it is a func. Otherwise, the copy
145 // we made earlier via the copy constructor is already a deep copy.
146 if (std::holds_alternative<func>(b1->value())) {
147 *b2 = copy(func_map, *b1);
148 }
149 }
150
151 // Construct the return value and put it into the cache.
152 auto ex = expression{std::move(f_copy)};
153 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ex});
154 // NOTE: an expression cannot contain itself.
155 assert(flag); // LCOV_EXCL_LINE
156
157 return ex;
158 } else {
159 return expression{v};
160 }
161 },
162 e.value());
163 }
164
165 } // namespace
166
167 } // namespace detail
168
// Create a deep copy of the expression e.
expression copy(const expression &e)
{
    // Cache of the functions copied so far, initially empty.
    std::unordered_map<const void *, expression> cache;

    return detail::copy(cache, e);
}
175
176 inline namespace literals
177 {
178
operator ""_dbl(long double x)179 expression operator""_dbl(long double x)
180 {
181 return expression{static_cast<double>(x)};
182 }
183
operator ""_dbl(unsigned long long n)184 expression operator""_dbl(unsigned long long n)
185 {
186 return expression{static_cast<double>(n)};
187 }
188
operator ""_ldbl(long double x)189 expression operator""_ldbl(long double x)
190 {
191 return expression{x};
192 }
193
operator ""_ldbl(unsigned long long n)194 expression operator""_ldbl(unsigned long long n)
195 {
196 return expression{static_cast<long double>(n)};
197 }
198
operator ""_var(const char * s,std::size_t n)199 expression operator""_var(const char *s, std::size_t n)
200 {
201 return expression{variable{std::string(s, n)}};
202 }
203
204 } // namespace literals
205
206 namespace detail
207 {
208
prime_wrapper(std::string s)209 prime_wrapper::prime_wrapper(std::string s) : m_str(std::move(s)) {}
210
211 prime_wrapper::prime_wrapper(const prime_wrapper &) = default;
212
213 prime_wrapper::prime_wrapper(prime_wrapper &&) noexcept = default;
214
215 prime_wrapper &prime_wrapper::operator=(const prime_wrapper &) = default;
216
217 prime_wrapper &prime_wrapper::operator=(prime_wrapper &&) noexcept = default;
218
219 prime_wrapper::~prime_wrapper() = default;
220
operator =(expression e)221 std::pair<expression, expression> prime_wrapper::operator=(expression e) &&
222 {
223 return std::pair{expression{variable{std::move(m_str)}}, std::move(e)};
224 }
225
226 } // namespace detail
227
prime(expression e)228 detail::prime_wrapper prime(expression e)
229 {
230 return std::visit(
231 [&e](auto &v) -> detail::prime_wrapper {
232 if constexpr (std::is_same_v<variable, detail::uncvref_t<decltype(v)>>) {
233 return detail::prime_wrapper{std::move(v.name())};
234 } else {
235 throw std::invalid_argument(
236 "Cannot apply the prime() operator to the non-variable expression '{}'"_format(e));
237 }
238 },
239 e.value());
240 }
241
242 inline namespace literals
243 {
244
operator ""_p(const char * s,std::size_t n)245 detail::prime_wrapper operator""_p(const char *s, std::size_t n)
246 {
247 return detail::prime_wrapper{std::string(s, n)};
248 }
249
250 } // namespace literals
251
252 namespace detail
253 {
254
255 namespace
256 {
257
get_variables(std::unordered_set<const void * > & func_set,std::set<std::string> & s_set,const expression & e)258 void get_variables(std::unordered_set<const void *> &func_set, std::set<std::string> &s_set, const expression &e)
259 {
260 std::visit(
261 [&func_set, &s_set](const auto &arg) {
262 using type = detail::uncvref_t<decltype(arg)>;
263
264 if constexpr (std::is_same_v<type, func>) {
265 const auto f_id = arg.get_ptr();
266
267 if (func_set.find(f_id) != func_set.end()) {
268 // We already determined the list of variables for the
269 // current function, exit.
270 return;
271 }
272
273 // Determine the list of variables for each
274 // function argument.
275 for (const auto &farg : arg.args()) {
276 get_variables(func_set, s_set, farg);
277 }
278
279 // Add the id of f to the set.
280 [[maybe_unused]] const auto [_, flag] = func_set.insert(f_id);
281 // NOTE: an expression cannot contain itself.
282 assert(flag);
283 } else if constexpr (std::is_same_v<type, variable>) {
284 s_set.insert(arg.name());
285 }
286 },
287 e.value());
288 }
289
rename_variables(std::unordered_set<const void * > & func_set,expression & e,const std::unordered_map<std::string,std::string> & repl_map)290 void rename_variables(std::unordered_set<const void *> &func_set, expression &e,
291 const std::unordered_map<std::string, std::string> &repl_map)
292 {
293 std::visit(
294 [&func_set, &repl_map](auto &arg) {
295 using type = detail::uncvref_t<decltype(arg)>;
296
297 if constexpr (std::is_same_v<type, func>) {
298 const auto f_id = arg.get_ptr();
299
300 if (func_set.find(f_id) != func_set.end()) {
301 // We already renamed variables for the current function,
302 // just return.
303 return;
304 }
305
306 for (auto [b, e] = arg.get_mutable_args_it(); b != e; ++b) {
307 rename_variables(func_set, *b, repl_map);
308 }
309
310 // Add the id of f to the set.
311 [[maybe_unused]] const auto [_, flag] = func_set.insert(f_id);
312 // NOTE: an expression cannot contain itself.
313 assert(flag);
314 } else if constexpr (std::is_same_v<type, variable>) {
315 if (auto it = repl_map.find(arg.name()); it != repl_map.end()) {
316 arg.name() = it->second;
317 }
318 }
319 },
320 e.value());
321 }
322
323 } // namespace
324
325 } // namespace detail
326
get_variables(const expression & e)327 std::vector<std::string> get_variables(const expression &e)
328 {
329 std::unordered_set<const void *> func_set;
330 std::set<std::string> s_set;
331
332 detail::get_variables(func_set, s_set, e);
333
334 return std::vector<std::string>(s_set.begin(), s_set.end());
335 }
336
rename_variables(expression & e,const std::unordered_map<std::string,std::string> & repl_map)337 void rename_variables(expression &e, const std::unordered_map<std::string, std::string> &repl_map)
338 {
339 std::unordered_set<const void *> func_set;
340
341 detail::rename_variables(func_set, e, repl_map);
342 }
343
swap(expression & ex0,expression & ex1)344 void swap(expression &ex0, expression &ex1) noexcept
345 {
346 std::swap(ex0.value(), ex1.value());
347 }
348
349 // NOTE: this implementation does not take advantage of potentially
350 // repeating subexpressions. This is not currently a problem because
351 // hashing is needed only in the CSE for the decomposition, which involves
352 // only trivial expressions. However, this would likely be needed by a to_sympy()
353 // implementation in heyoka.py which allows for a dictionary of custom
354 // substitutions to be provided by the user.
hash(const expression & ex)355 std::size_t hash(const expression &ex)
356 {
357 return std::visit([](const auto &v) { return hash(v); }, ex.value());
358 }
359
operator <<(std::ostream & os,const expression & e)360 std::ostream &operator<<(std::ostream &os, const expression &e)
361 {
362 return std::visit([&os](const auto &arg) -> std::ostream & { return os << arg; }, e.value());
363 }
364
operator +(expression e)365 expression operator+(expression e)
366 {
367 return e;
368 }
369
operator -(expression e)370 expression operator-(expression e)
371 {
372 if (auto num_ptr = std::get_if<number>(&e.value())) {
373 // Simplify -number to its numerical value.
374 return expression{-std::move(*num_ptr)};
375 } else {
376 if (auto fptr = detail::is_neg(e)) {
377 // Simplify -(-x) to x.
378 assert(!fptr->args().empty()); // LCOV_EXCL_LINE
379 return fptr->args()[0];
380 } else {
381 return neg(std::move(e));
382 }
383 }
384 }
385
operator +(expression e1,expression e2)386 expression operator+(expression e1, expression e2)
387 {
388 // Simplify x + neg(y) to x - y.
389 if (auto fptr = detail::is_neg(e2)) {
390 assert(!fptr->args().empty()); // LCOV_EXCL_LINE
391 return std::move(e1) - fptr->args()[0];
392 }
393
394 auto visitor = [](auto &&v1, auto &&v2) {
395 using type1 = detail::uncvref_t<decltype(v1)>;
396 using type2 = detail::uncvref_t<decltype(v2)>;
397
398 if constexpr (std::is_same_v<type1, number> && std::is_same_v<type2, number>) {
399 // Both are numbers, add them and return the result.
400 return expression{std::forward<decltype(v1)>(v1) + std::forward<decltype(v2)>(v2)};
401 } else if constexpr (std::is_same_v<type1, number>) {
402 // e1 number, e2 symbolic.
403 if (is_zero(v1)) {
404 // 0 + e2 = e2.
405 return expression{std::forward<decltype(v2)>(v2)};
406 }
407 if constexpr (std::is_same_v<func, type2>) {
408 if (auto pbop = v2.template extract<detail::binary_op>();
409 pbop != nullptr && pbop->op() == detail::binary_op::type::add
410 && std::holds_alternative<number>(pbop->args()[0].value())) {
411 // e2 = a + x, where a is a number. Simplify e1 + (a + x) -> c + x, where c = e1 + a.
412 return expression{std::forward<decltype(v1)>(v1)} + pbop->args()[0] + pbop->args()[1];
413 }
414 }
415
416 // NOTE: fall through the standard case.
417 } else if constexpr (std::is_same_v<type2, number>) {
418 // e1 symbolic, e2 number. Swap the operands so that the number comes first.
419 return expression{std::forward<decltype(v2)>(v2)} + expression{std::forward<decltype(v1)>(v1)};
420 }
421
422 // The standard case.
423 return add(expression{std::forward<decltype(v1)>(v1)}, expression{std::forward<decltype(v2)>(v2)});
424 };
425
426 return std::visit(visitor, std::move(e1.value()), std::move(e2.value()));
427 }
428
operator -(expression e1,expression e2)429 expression operator-(expression e1, expression e2)
430 {
431 // Simplify x - (-y) to x + y.
432 if (auto fptr = detail::is_neg(e2)) {
433 assert(!fptr->args().empty()); // LCOV_EXCL_LINE
434 return std::move(e1) + fptr->args()[0];
435 }
436
437 auto visitor = [](auto &&v1, auto &&v2) {
438 using type1 = detail::uncvref_t<decltype(v1)>;
439 using type2 = detail::uncvref_t<decltype(v2)>;
440
441 if constexpr (std::is_same_v<type1, number> && std::is_same_v<type2, number>) {
442 // Both are numbers, subtract them.
443 return expression{std::forward<decltype(v1)>(v1) - std::forward<decltype(v2)>(v2)};
444 } else if constexpr (std::is_same_v<type1, number>) {
445 // e1 number, e2 symbolic.
446 if (is_zero(v1)) {
447 // 0 - e2 = -e2.
448 return -expression{std::forward<decltype(v2)>(v2)};
449 }
450 // NOTE: fall through the standard case if e1 is not zero.
451 } else if constexpr (std::is_same_v<type2, number>) {
452 // e1 symbolic, e2 number. Turn e1 - e2 into e1 + (-e2),
453 // because addition provides more simplification capabilities.
454 return expression{std::forward<decltype(v1)>(v1)} + expression{-std::forward<decltype(v2)>(v2)};
455 }
456
457 // The standard case.
458 return sub(expression{std::forward<decltype(v1)>(v1)}, expression{std::forward<decltype(v2)>(v2)});
459 };
460
461 return std::visit(visitor, std::move(e1.value()), std::move(e2.value()));
462 }
463
operator *(expression e1,expression e2)464 expression operator*(expression e1, expression e2)
465 {
466 auto fptr1 = detail::is_neg(e1);
467 auto fptr2 = detail::is_neg(e2);
468
469 if (fptr1 != nullptr && fptr2 != nullptr) {
470 // Simplify (-x) * (-y) into x*y.
471 assert(!fptr1->args().empty()); // LCOV_EXCL_LINE
472 assert(!fptr2->args().empty()); // LCOV_EXCL_LINE
473 return fptr1->args()[0] * fptr2->args()[0];
474 }
475
476 // Simplify x*x -> square(x) if x is not a number (otherwise,
477 // we will numerically compute the result below).
478 if (e1 == e2 && !std::holds_alternative<number>(e1.value())) {
479 return square(std::move(e1));
480 }
481
482 auto visitor = [fptr2](auto &&v1, auto &&v2) {
483 using type1 = detail::uncvref_t<decltype(v1)>;
484 using type2 = detail::uncvref_t<decltype(v2)>;
485
486 if constexpr (std::is_same_v<type1, number> && std::is_same_v<type2, number>) {
487 // Both are numbers, multiply them.
488 return expression{std::forward<decltype(v1)>(v1) * std::forward<decltype(v2)>(v2)};
489 } else if constexpr (std::is_same_v<type1, number>) {
490 // e1 number, e2 symbolic.
491 if (is_zero(v1)) {
492 // 0 * e2 = 0.
493 return expression{number{0.}};
494 }
495 if (is_one(v1)) {
496 // 1 * e2 = e2.
497 return expression{std::forward<decltype(v2)>(v2)};
498 }
499 if (is_negative_one(v1)) {
500 // -1 * e2 = -e2.
501 return -expression{std::forward<decltype(v2)>(v2)};
502 }
503 if (fptr2 != nullptr) {
504 // a * (-x) = (-a) * x.
505 assert(!fptr2->args().empty()); // LCOV_EXCL_LINE
506 return expression{-std::forward<decltype(v1)>(v1)} * fptr2->args()[0];
507 }
508 if constexpr (std::is_same_v<func, type2>) {
509 if (auto pbop = v2.template extract<detail::binary_op>();
510 pbop != nullptr && pbop->op() == detail::binary_op::type::mul
511 && std::holds_alternative<number>(pbop->args()[0].value())) {
512 // e2 = a * x, where a is a number. Simplify e1 * (a * x) -> c * x, where c = e1 * a.
513 return expression{std::forward<decltype(v1)>(v1)} * pbop->args()[0] * pbop->args()[1];
514 }
515 }
516
517 // NOTE: fall through the standard case.
518 } else if constexpr (std::is_same_v<type2, number>) {
519 // e1 symbolic, e2 number. Swap the operands so that the number comes first.
520 return expression{std::forward<decltype(v2)>(v2)} * expression{std::forward<decltype(v1)>(v1)};
521 }
522
523 // The standard case.
524 return mul(expression{std::forward<decltype(v1)>(v1)}, expression{std::forward<decltype(v2)>(v2)});
525 };
526
527 return std::visit(visitor, std::move(e1.value()), std::move(e2.value()));
528 }
529
operator /(expression e1,expression e2)530 expression operator/(expression e1, expression e2)
531 {
532 auto fptr1 = detail::is_neg(e1);
533 auto fptr2 = detail::is_neg(e2);
534
535 if (fptr1 != nullptr && fptr2 != nullptr) {
536 // Simplify (-x) / (-y) into x/y.
537 assert(!fptr1->args().empty()); // LCOV_EXCL_LINE
538 assert(!fptr2->args().empty()); // LCOV_EXCL_LINE
539 return fptr1->args()[0] / fptr2->args()[0];
540 }
541
542 auto visitor = [fptr1, fptr2](auto &&v1, auto &&v2) {
543 using type1 = detail::uncvref_t<decltype(v1)>;
544 using type2 = detail::uncvref_t<decltype(v2)>;
545
546 if constexpr (std::is_same_v<type2, number>) {
547 // If the divisor is zero, always raise an error.
548 if (is_zero(v2)) {
549 throw zero_division_error("Division by zero");
550 }
551 }
552
553 if constexpr (std::is_same_v<type1, number> && std::is_same_v<type2, number>) {
554 // Both are numbers, divide them.
555 return expression{std::forward<decltype(v1)>(v1) / std::forward<decltype(v2)>(v2)};
556 } else if constexpr (std::is_same_v<type2, number>) {
557 // e1 is symbolic, e2 a number.
558 if (is_one(v2)) {
559 // e1 / 1 = e1.
560 return expression{std::forward<decltype(v1)>(v1)};
561 }
562 if (is_negative_one(v2)) {
563 // e1 / -1 = -e1.
564 return -expression{std::forward<decltype(v1)>(v1)};
565 }
566 if (fptr1 != nullptr) {
567 // (-e1) / a = e1 / (-a).
568 assert(!fptr1->args().empty()); // LCOV_EXCL_LINE
569 return fptr1->args()[0] / expression{-std::forward<decltype(v2)>(v2)};
570 }
571 if constexpr (std::is_same_v<func, type1>) {
572 if (auto pbop = v1.template extract<detail::binary_op>();
573 pbop != nullptr && pbop->op() == detail::binary_op::type::div
574 && std::holds_alternative<number>(pbop->args()[1].value())) {
575 // e1 = x / a, where a is a number. Simplify (x / a) / b -> x / (a * b).
576 return pbop->args()[0] / (pbop->args()[1] * expression{std::forward<decltype(v2)>(v2)});
577 }
578 }
579
580 // NOTE: fall through to the standard case.
581 } else if constexpr (std::is_same_v<type1, number>) {
582 // e1 is a number, e2 is symbolic.
583 if (is_zero(v1)) {
584 // 0 / e2 == 0.
585 return expression{number{0.}};
586 }
587 if (fptr2 != nullptr) {
588 // a / (-e2) = (-a) / e2.
589 assert(!fptr2->args().empty()); // LCOV_EXCL_LINE
590 return expression{-std::forward<decltype(v1)>(v1)} / fptr2->args()[0];
591 }
592
593 // NOTE: fall through to the standard case.
594 }
595
596 // The standard case.
597 return div(expression{std::forward<decltype(v1)>(v1)}, expression{std::forward<decltype(v2)>(v2)});
598 };
599
600 return std::visit(visitor, std::move(e1.value()), std::move(e2.value()));
601 }
602
operator +(expression ex,double x)603 expression operator+(expression ex, double x)
604 {
605 return std::move(ex) + expression{x};
606 }
607
operator +(expression ex,long double x)608 expression operator+(expression ex, long double x)
609 {
610 return std::move(ex) + expression{x};
611 }
612
613 #if defined(HEYOKA_HAVE_REAL128)
614
operator +(expression ex,mppp::real128 x)615 expression operator+(expression ex, mppp::real128 x)
616 {
617 return std::move(ex) + expression{x};
618 }
619
620 #endif
621
operator +(double x,expression ex)622 expression operator+(double x, expression ex)
623 {
624 return expression{x} + std::move(ex);
625 }
626
operator +(long double x,expression ex)627 expression operator+(long double x, expression ex)
628 {
629 return expression{x} + std::move(ex);
630 }
631
632 #if defined(HEYOKA_HAVE_REAL128)
633
operator +(mppp::real128 x,expression ex)634 expression operator+(mppp::real128 x, expression ex)
635 {
636 return expression{x} + std::move(ex);
637 }
638
639 #endif
640
operator -(expression ex,double x)641 expression operator-(expression ex, double x)
642 {
643 return std::move(ex) - expression{x};
644 }
645
operator -(expression ex,long double x)646 expression operator-(expression ex, long double x)
647 {
648 return std::move(ex) - expression{x};
649 }
650
651 #if defined(HEYOKA_HAVE_REAL128)
652
operator -(expression ex,mppp::real128 x)653 expression operator-(expression ex, mppp::real128 x)
654 {
655 return std::move(ex) - expression{x};
656 }
657
658 #endif
659
operator -(double x,expression ex)660 expression operator-(double x, expression ex)
661 {
662 return expression{x} - std::move(ex);
663 }
664
operator -(long double x,expression ex)665 expression operator-(long double x, expression ex)
666 {
667 return expression{x} - std::move(ex);
668 }
669
670 #if defined(HEYOKA_HAVE_REAL128)
671
operator -(mppp::real128 x,expression ex)672 expression operator-(mppp::real128 x, expression ex)
673 {
674 return expression{x} - std::move(ex);
675 }
676
677 #endif
678
operator *(expression ex,double x)679 expression operator*(expression ex, double x)
680 {
681 return std::move(ex) * expression{x};
682 }
683
operator *(expression ex,long double x)684 expression operator*(expression ex, long double x)
685 {
686 return std::move(ex) * expression{x};
687 }
688
689 #if defined(HEYOKA_HAVE_REAL128)
690
operator *(expression ex,mppp::real128 x)691 expression operator*(expression ex, mppp::real128 x)
692 {
693 return std::move(ex) * expression{x};
694 }
695
696 #endif
697
operator *(double x,expression ex)698 expression operator*(double x, expression ex)
699 {
700 return expression{x} * std::move(ex);
701 }
702
operator *(long double x,expression ex)703 expression operator*(long double x, expression ex)
704 {
705 return expression{x} * std::move(ex);
706 }
707
708 #if defined(HEYOKA_HAVE_REAL128)
709
operator *(mppp::real128 x,expression ex)710 expression operator*(mppp::real128 x, expression ex)
711 {
712 return expression{x} * std::move(ex);
713 }
714
715 #endif
716
operator /(expression ex,double x)717 expression operator/(expression ex, double x)
718 {
719 return std::move(ex) / expression{x};
720 }
721
operator /(expression ex,long double x)722 expression operator/(expression ex, long double x)
723 {
724 return std::move(ex) / expression{x};
725 }
726
727 #if defined(HEYOKA_HAVE_REAL128)
728
operator /(expression ex,mppp::real128 x)729 expression operator/(expression ex, mppp::real128 x)
730 {
731 return std::move(ex) / expression{x};
732 }
733
734 #endif
735
operator /(double x,expression ex)736 expression operator/(double x, expression ex)
737 {
738 return expression{x} / std::move(ex);
739 }
740
operator /(long double x,expression ex)741 expression operator/(long double x, expression ex)
742 {
743 return expression{x} / std::move(ex);
744 }
745
746 #if defined(HEYOKA_HAVE_REAL128)
747
operator /(mppp::real128 x,expression ex)748 expression operator/(mppp::real128 x, expression ex)
749 {
750 return expression{x} / std::move(ex);
751 }
752
753 #endif
754
operator +=(expression & x,expression e)755 expression &operator+=(expression &x, expression e)
756 {
757 return x = std::move(x) + std::move(e);
758 }
759
operator -=(expression & x,expression e)760 expression &operator-=(expression &x, expression e)
761 {
762 return x = std::move(x) - std::move(e);
763 }
764
operator *=(expression & x,expression e)765 expression &operator*=(expression &x, expression e)
766 {
767 return x = std::move(x) * std::move(e);
768 }
769
operator /=(expression & x,expression e)770 expression &operator/=(expression &x, expression e)
771 {
772 return x = std::move(x) / std::move(e);
773 }
774
operator +=(expression & ex,double x)775 expression &operator+=(expression &ex, double x)
776 {
777 return ex += expression{x};
778 }
779
operator +=(expression & ex,long double x)780 expression &operator+=(expression &ex, long double x)
781 {
782 return ex += expression{x};
783 }
784
785 #if defined(HEYOKA_HAVE_REAL128)
786
operator +=(expression & ex,mppp::real128 x)787 expression &operator+=(expression &ex, mppp::real128 x)
788 {
789 return ex += expression{x};
790 }
791
792 #endif
793
operator -=(expression & ex,double x)794 expression &operator-=(expression &ex, double x)
795 {
796 return ex -= expression{x};
797 }
798
operator -=(expression & ex,long double x)799 expression &operator-=(expression &ex, long double x)
800 {
801 return ex -= expression{x};
802 }
803
804 #if defined(HEYOKA_HAVE_REAL128)
805
operator -=(expression & ex,mppp::real128 x)806 expression &operator-=(expression &ex, mppp::real128 x)
807 {
808 return ex -= expression{x};
809 }
810
811 #endif
812
operator *=(expression & ex,double x)813 expression &operator*=(expression &ex, double x)
814 {
815 return ex *= expression{x};
816 }
817
operator *=(expression & ex,long double x)818 expression &operator*=(expression &ex, long double x)
819 {
820 return ex *= expression{x};
821 }
822
823 #if defined(HEYOKA_HAVE_REAL128)
824
operator *=(expression & ex,mppp::real128 x)825 expression &operator*=(expression &ex, mppp::real128 x)
826 {
827 return ex *= expression{x};
828 }
829
830 #endif
831
operator /=(expression & ex,double x)832 expression &operator/=(expression &ex, double x)
833 {
834 return ex /= expression{x};
835 }
836
operator /=(expression & ex,long double x)837 expression &operator/=(expression &ex, long double x)
838 {
839 return ex /= expression{x};
840 }
841
842 #if defined(HEYOKA_HAVE_REAL128)
843
operator /=(expression & ex,mppp::real128 x)844 expression &operator/=(expression &ex, mppp::real128 x)
845 {
846 return ex /= expression{x};
847 }
848
849 #endif
850
operator ==(const expression & e1,const expression & e2)851 bool operator==(const expression &e1, const expression &e2)
852 {
853 auto visitor = [](const auto &v1, const auto &v2) {
854 using type1 = detail::uncvref_t<decltype(v1)>;
855 using type2 = detail::uncvref_t<decltype(v2)>;
856
857 if constexpr (std::is_same_v<type1, type2>) {
858 return v1 == v2;
859 } else {
860 return false;
861 }
862 };
863
864 return std::visit(visitor, e1.value(), e2.value());
865 }
866
operator !=(const expression & e1,const expression & e2)867 bool operator!=(const expression &e1, const expression &e2)
868 {
869 return !(e1 == e2);
870 }
871
872 namespace detail
873 {
874
875 namespace
876 {
877
get_n_nodes(std::unordered_map<const void *,std::size_t> & func_map,const expression & e)878 std::size_t get_n_nodes(std::unordered_map<const void *, std::size_t> &func_map, const expression &e)
879 {
880 return std::visit(
881 [&func_map](const auto &arg) -> std::size_t {
882 if constexpr (std::is_same_v<func, detail::uncvref_t<decltype(arg)>>) {
883 const auto f_id = arg.get_ptr();
884
885 if (auto it = func_map.find(f_id); it != func_map.end()) {
886 // We already computed the number of nodes for the current
887 // function, return it.
888 return it->second;
889 }
890
891 std::size_t retval = 1;
892 for (const auto &ex : arg.args()) {
893 retval += get_n_nodes(func_map, ex);
894 }
895
896 // Store the number of nodes for the current function
897 // in the cache.
898 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, retval});
899 // NOTE: an expression cannot contain itself.
900 assert(flag);
901
902 return retval;
903 } else {
904 return 1;
905 }
906 },
907 e.value());
908 }
909
910 } // namespace
911
912 } // namespace detail
913
get_n_nodes(const expression & e)914 std::size_t get_n_nodes(const expression &e)
915 {
916 std::unordered_map<const void *, std::size_t> func_map;
917
918 return detail::get_n_nodes(func_map, e);
919 }
920
921 namespace detail
922 {
923
diff(std::unordered_map<const void *,expression> & func_map,const expression & e,const std::string & s)924 expression diff(std::unordered_map<const void *, expression> &func_map, const expression &e, const std::string &s)
925 {
926 return std::visit(
927 [&func_map, &s](const auto &arg) {
928 using type = detail::uncvref_t<decltype(arg)>;
929
930 if constexpr (std::is_same_v<type, number>) {
931 return std::visit([](const auto &v) { return expression{number{detail::uncvref_t<decltype(v)>(0)}}; },
932 arg.value());
933 } else if constexpr (std::is_same_v<type, param>) {
934 // NOTE: if we ever implement single-precision support,
935 // this should be probably changed into 0_flt (i.e., the lowest
936 // precision numerical type), so that it does not trigger
937 // type promotions in numerical constants. Other similar
938 // occurrences as well (e.g., diff for variable).
939 return 0_dbl;
940 } else if constexpr (std::is_same_v<type, variable>) {
941 if (s == arg.name()) {
942 return 1_dbl;
943 } else {
944 return 0_dbl;
945 }
946 } else {
947 const auto f_id = arg.get_ptr();
948
949 if (auto it = func_map.find(f_id); it != func_map.end()) {
950 // We already performed diff on the current function,
951 // fetch the result from the cache.
952 return it->second;
953 }
954
955 auto ret = arg.diff(func_map, s);
956
957 // Put the return value in the cache.
958 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ret});
959 // NOTE: an expression cannot contain itself.
960 assert(flag);
961
962 return ret;
963 }
964 },
965 e.value());
966 }
967
diff(std::unordered_map<const void *,expression> & func_map,const expression & e,const param & p)968 expression diff(std::unordered_map<const void *, expression> &func_map, const expression &e, const param &p)
969 {
970 return std::visit(
971 [&func_map, &p](const auto &arg) {
972 using type = detail::uncvref_t<decltype(arg)>;
973
974 if constexpr (std::is_same_v<type, number>) {
975 return std::visit([](const auto &v) { return expression{number{detail::uncvref_t<decltype(v)>(0)}}; },
976 arg.value());
977 } else if constexpr (std::is_same_v<type, param>) {
978 if (p.idx() == arg.idx()) {
979 return 1_dbl;
980 } else {
981 return 0_dbl;
982 }
983 } else if constexpr (std::is_same_v<type, variable>) {
984 return 0_dbl;
985 } else {
986 const auto f_id = arg.get_ptr();
987
988 if (auto it = func_map.find(f_id); it != func_map.end()) {
989 // We already performed diff on the current function,
990 // fetch the result from the cache.
991 return it->second;
992 }
993
994 auto ret = arg.diff(func_map, p);
995
996 // Put the return value in the cache.
997 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ret});
998 // NOTE: an expression cannot contain itself.
999 assert(flag);
1000
1001 return ret;
1002 }
1003 },
1004 e.value());
1005 }
1006
1007 } // namespace detail
1008
// Compute the derivative of e with respect
// to the variable named s.
expression diff(const expression &e, const std::string &s)
{
    // Cache of the function derivatives, initially empty.
    std::unordered_map<const void *, expression> cache;

    return detail::diff(cache, e, s);
}
1015
// Differentiate e with respect to the parameter p.
expression diff(const expression &e, const param &p)
{
    // Fresh cache of already-differentiated function nodes,
    // private to this invocation.
    std::unordered_map<const void *, expression> cache;

    return detail::diff(cache, e, p);
}
1022
// Differentiate e with respect to x, which must be either a variable
// or a parameter. Any other kind of expression results in an
// std::invalid_argument exception.
expression diff(const expression &e, const expression &x)
{
    return std::visit(
        [&e](const auto &wrt) -> expression {
            using wrt_t = detail::uncvref_t<decltype(wrt)>;

            if constexpr (std::is_same_v<wrt_t, param>) {
                return diff(e, wrt);
            } else if constexpr (std::is_same_v<wrt_t, variable>) {
                // Variables are identified by their name.
                return diff(e, wrt.name());
            } else {
                throw std::invalid_argument(
                    "Derivatives are currently supported only with respect to variables and parameters");
            }
        },
        x.value());
}
1038
1039 namespace detail
1040 {
1041
1042 namespace
1043 {
1044
1045 // NOTE: an in-place API would perform better.
subs(std::unordered_map<const void *,expression> & func_map,const expression & ex,const std::unordered_map<std::string,expression> & smap)1046 expression subs(std::unordered_map<const void *, expression> &func_map, const expression &ex,
1047 const std::unordered_map<std::string, expression> &smap)
1048 {
1049 return std::visit(
1050 [&func_map, &smap](const auto &arg) {
1051 using type = detail::uncvref_t<decltype(arg)>;
1052
1053 if constexpr (std::is_same_v<type, number> || std::is_same_v<type, param>) {
1054 return expression{arg};
1055 } else if constexpr (std::is_same_v<type, variable>) {
1056 if (auto it = smap.find(arg.name()); it == smap.end()) {
1057 return expression{arg};
1058 } else {
1059 return it->second;
1060 }
1061 } else {
1062 const auto f_id = arg.get_ptr();
1063
1064 if (auto it = func_map.find(f_id); it != func_map.end()) {
1065 // We already performed substitution on the current function,
1066 // fetch the result from the cache.
1067 return it->second;
1068 }
1069
1070 // NOTE: this creates a separate instance of arg, but its
1071 // arguments are shallow-copied.
1072 auto tmp = arg.copy();
1073
1074 for (auto [b, e] = tmp.get_mutable_args_it(); b != e; ++b) {
1075 *b = subs(func_map, *b, smap);
1076 }
1077
1078 // Put the return value in the cache.
1079 auto ret = expression{std::move(tmp)};
1080 [[maybe_unused]] const auto [_, flag] = func_map.insert(std::pair{f_id, ret});
1081 // NOTE: an expression cannot contain itself.
1082 assert(flag);
1083
1084 return ret;
1085 }
1086 },
1087 ex.value());
1088 }
1089
1090 } // namespace
1091
1092 } // namespace detail
1093
// Substitute the variables of e according to smap (variable name ->
// replacement expression).
expression subs(const expression &e, const std::unordered_map<std::string, expression> &smap)
{
    // Fresh cache of already-processed function nodes,
    // private to this invocation.
    std::unordered_map<const void *, expression> cache;

    return detail::subs(cache, e, smap);
}
1100
1101 namespace detail
1102 {
1103
1104 namespace
1105 {
1106
1107 // Pairwise reduction of a vector of expressions.
1108 template <typename F>
pairwise_reduce(const F & func,std::vector<expression> list)1109 expression pairwise_reduce(const F &func, std::vector<expression> list)
1110 {
1111 assert(!list.empty());
1112
1113 // LCOV_EXCL_START
1114 if (list.size() == std::numeric_limits<decltype(list.size())>::max()) {
1115 throw std::overflow_error("Overflow detected in pairwise_reduce()");
1116 }
1117 // LCOV_EXCL_STOP
1118
1119 while (list.size() != 1u) {
1120 const auto cur_size = list.size();
1121
1122 // Init the new list. The size will be halved, +1 if the
1123 // current size is odd.
1124 const auto next_size = cur_size / 2u + cur_size % 2u;
1125 std::vector<expression> new_list(next_size);
1126
1127 tbb::parallel_for(tbb::blocked_range<decltype(new_list.size())>(0, new_list.size()),
1128 [&list, &new_list, cur_size, &func](const auto &r) {
1129 for (auto i = r.begin(); i != r.end(); ++i) {
1130 if (i * 2u == cur_size - 1u) {
1131 // list has an odd size, and we are at the last element of list.
1132 // Just move it to new_list.
1133 new_list[i] = std::move(list.back());
1134 } else {
1135 new_list[i] = func(std::move(list[i * 2u]), std::move(list[i * 2u + 1u]));
1136 }
1137 }
1138 });
1139
1140 new_list.swap(list);
1141 }
1142
1143 return std::move(list[0]);
1144 }
1145
1146 } // namespace
1147
1148 } // namespace detail
1149
1150 // Pairwise product.
pairwise_prod(std::vector<expression> prod)1151 expression pairwise_prod(std::vector<expression> prod)
1152 {
1153 if (prod.empty()) {
1154 return 1_dbl;
1155 }
1156
1157 return detail::pairwise_reduce([](expression &&a, expression &&b) { return std::move(a) * std::move(b); },
1158 std::move(prod));
1159 }
1160
eval_dbl(const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<double> & pars)1161 double eval_dbl(const expression &e, const std::unordered_map<std::string, double> &map,
1162 const std::vector<double> &pars)
1163 {
1164 return std::visit([&](const auto &arg) { return eval_dbl(arg, map, pars); }, e.value());
1165 }
1166
eval_ldbl(const expression & e,const std::unordered_map<std::string,long double> & map,const std::vector<long double> & pars)1167 long double eval_ldbl(const expression &e, const std::unordered_map<std::string, long double> &map,
1168 const std::vector<long double> &pars)
1169 {
1170 return std::visit([&](const auto &arg) { return eval_ldbl(arg, map, pars); }, e.value());
1171 }
1172
1173 #if defined(HEYOKA_HAVE_REAL128)
eval_f128(const expression & e,const std::unordered_map<std::string,mppp::real128> & map,const std::vector<mppp::real128> & pars)1174 mppp::real128 eval_f128(const expression &e, const std::unordered_map<std::string, mppp::real128> &map,
1175 const std::vector<mppp::real128> &pars)
1176 {
1177 return std::visit([&](const auto &arg) { return eval_f128(arg, map, pars); }, e.value());
1178 }
1179 #endif
1180
eval_batch_dbl(std::vector<double> & retval,const expression & e,const std::unordered_map<std::string,std::vector<double>> & map,const std::vector<double> & pars)1181 void eval_batch_dbl(std::vector<double> &retval, const expression &e,
1182 const std::unordered_map<std::string, std::vector<double>> &map, const std::vector<double> &pars)
1183 {
1184 std::visit([&](const auto &arg) { eval_batch_dbl(retval, arg, map, pars); }, e.value());
1185 }
1186
compute_connections(const expression & e)1187 std::vector<std::vector<std::size_t>> compute_connections(const expression &e)
1188 {
1189 std::vector<std::vector<std::size_t>> node_connections;
1190 std::size_t node_counter = 0u;
1191 update_connections(node_connections, e, node_counter);
1192 return node_connections;
1193 }
1194
update_connections(std::vector<std::vector<std::size_t>> & node_connections,const expression & e,std::size_t & node_counter)1195 void update_connections(std::vector<std::vector<std::size_t>> &node_connections, const expression &e,
1196 std::size_t &node_counter)
1197 {
1198 std::visit([&node_connections,
1199 &node_counter](const auto &arg) { update_connections(node_connections, arg, node_counter); },
1200 e.value());
1201 }
1202
compute_node_values_dbl(const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<std::vector<std::size_t>> & node_connections)1203 std::vector<double> compute_node_values_dbl(const expression &e, const std::unordered_map<std::string, double> &map,
1204 const std::vector<std::vector<std::size_t>> &node_connections)
1205 {
1206 std::vector<double> node_values(node_connections.size());
1207 std::size_t node_counter = 0u;
1208 update_node_values_dbl(node_values, e, map, node_connections, node_counter);
1209 return node_values;
1210 }
1211
update_node_values_dbl(std::vector<double> & node_values,const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<std::vector<std::size_t>> & node_connections,std::size_t & node_counter)1212 void update_node_values_dbl(std::vector<double> &node_values, const expression &e,
1213 const std::unordered_map<std::string, double> &map,
1214 const std::vector<std::vector<std::size_t>> &node_connections, std::size_t &node_counter)
1215 {
1216 std::visit([&map, &node_values, &node_connections, &node_counter](
1217 const auto &arg) { update_node_values_dbl(node_values, arg, map, node_connections, node_counter); },
1218 e.value());
1219 }
1220
compute_grad_dbl(const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<std::vector<std::size_t>> & node_connections)1221 std::unordered_map<std::string, double> compute_grad_dbl(const expression &e,
1222 const std::unordered_map<std::string, double> &map,
1223 const std::vector<std::vector<std::size_t>> &node_connections)
1224 {
1225 std::unordered_map<std::string, double> grad;
1226 auto node_values = compute_node_values_dbl(e, map, node_connections);
1227 std::size_t node_counter = 0u;
1228 update_grad_dbl(grad, e, map, node_values, node_connections, node_counter);
1229 return grad;
1230 }
1231
update_grad_dbl(std::unordered_map<std::string,double> & grad,const expression & e,const std::unordered_map<std::string,double> & map,const std::vector<double> & node_values,const std::vector<std::vector<std::size_t>> & node_connections,std::size_t & node_counter,double acc)1232 void update_grad_dbl(std::unordered_map<std::string, double> &grad, const expression &e,
1233 const std::unordered_map<std::string, double> &map, const std::vector<double> &node_values,
1234 const std::vector<std::vector<std::size_t>> &node_connections, std::size_t &node_counter,
1235 double acc)
1236 {
1237 std::visit(
1238 [&map, &grad, &node_values, &node_connections, &node_counter, &acc](const auto &arg) {
1239 update_grad_dbl(grad, arg, map, node_values, node_connections, node_counter, acc);
1240 },
1241 e.value());
1242 }
1243
1244 namespace detail
1245 {
1246
taylor_decompose(std::unordered_map<const void *,taylor_dc_t::size_type> & func_map,const expression & ex,taylor_dc_t & dc)1247 taylor_dc_t::size_type taylor_decompose(std::unordered_map<const void *, taylor_dc_t::size_type> &func_map,
1248 const expression &ex, taylor_dc_t &dc)
1249 {
1250 if (auto fptr = std::get_if<func>(&ex.value())) {
1251 return fptr->taylor_decompose(func_map, dc);
1252 } else {
1253 return 0;
1254 }
1255 }
1256
1257 } // namespace detail
1258
1259 // Decompose ex into dc. The return value is the index, in dc,
1260 // which corresponds to the decomposed version of ex.
1261 // If the return value is zero, ex was not decomposed.
taylor_decompose(const expression & ex,taylor_dc_t & dc)1262 taylor_dc_t::size_type taylor_decompose(const expression &ex, taylor_dc_t &dc)
1263 {
1264 std::unordered_map<const void *, taylor_dc_t::size_type> func_map;
1265
1266 return detail::taylor_decompose(func_map, ex, dc);
1267 }
1268
1269 namespace detail
1270 {
1271
1272 namespace
1273 {
1274
1275 template <typename T>
taylor_diff_impl(llvm_state & s,const expression & ex,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy)1276 llvm::Value *taylor_diff_impl(llvm_state &s, const expression &ex, const std::vector<std::uint32_t> &deps,
1277 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
1278 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size,
1279 bool high_accuracy)
1280 {
1281 if (auto fptr = std::get_if<func>(&ex.value())) {
1282 if constexpr (std::is_same_v<T, double>) {
1283 return fptr->taylor_diff_dbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1284 high_accuracy);
1285 } else if constexpr (std::is_same_v<T, long double>) {
1286 return fptr->taylor_diff_ldbl(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1287 high_accuracy);
1288 #if defined(HEYOKA_HAVE_REAL128)
1289 } else if constexpr (std::is_same_v<T, mppp::real128>) {
1290 return fptr->taylor_diff_f128(s, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1291 high_accuracy);
1292 #endif
1293 } else {
1294 static_assert(detail::always_false_v<T>, "Unhandled type.");
1295 }
1296 } else {
1297 // LCOV_EXCL_START
1298 throw std::invalid_argument("Taylor derivatives can be computed only for functions");
1299 // LCOV_EXCL_STOP
1300 }
1301 }
1302
1303 } // namespace
1304
1305 } // namespace detail
1306
taylor_diff_dbl(llvm_state & s,const expression & ex,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy)1307 llvm::Value *taylor_diff_dbl(llvm_state &s, const expression &ex, const std::vector<std::uint32_t> &deps,
1308 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
1309 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size,
1310 bool high_accuracy)
1311
1312 {
1313 return detail::taylor_diff_impl<double>(s, ex, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1314 high_accuracy);
1315 }
1316
taylor_diff_ldbl(llvm_state & s,const expression & ex,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy)1317 llvm::Value *taylor_diff_ldbl(llvm_state &s, const expression &ex, const std::vector<std::uint32_t> &deps,
1318 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
1319 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size,
1320 bool high_accuracy)
1321 {
1322 return detail::taylor_diff_impl<long double>(s, ex, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1323 high_accuracy);
1324 }
1325
1326 #if defined(HEYOKA_HAVE_REAL128)
1327
taylor_diff_f128(llvm_state & s,const expression & ex,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value * time_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool high_accuracy)1328 llvm::Value *taylor_diff_f128(llvm_state &s, const expression &ex, const std::vector<std::uint32_t> &deps,
1329 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *time_ptr,
1330 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size,
1331 bool high_accuracy)
1332 {
1333 return detail::taylor_diff_impl<mppp::real128>(s, ex, deps, arr, par_ptr, time_ptr, n_uvars, order, idx, batch_size,
1334 high_accuracy);
1335 }
1336
1337 #endif
1338
1339 namespace detail
1340 {
1341
1342 namespace
1343 {
1344
1345 template <typename T>
taylor_c_diff_func_impl(llvm_state & s,const expression & ex,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy)1346 llvm::Function *taylor_c_diff_func_impl(llvm_state &s, const expression &ex, std::uint32_t n_uvars,
1347 std::uint32_t batch_size, bool high_accuracy)
1348 {
1349 if (auto fptr = std::get_if<func>(&ex.value())) {
1350 if constexpr (std::is_same_v<T, double>) {
1351 return fptr->taylor_c_diff_func_dbl(s, n_uvars, batch_size, high_accuracy);
1352 } else if constexpr (std::is_same_v<T, long double>) {
1353 return fptr->taylor_c_diff_func_ldbl(s, n_uvars, batch_size, high_accuracy);
1354 #if defined(HEYOKA_HAVE_REAL128)
1355 } else if constexpr (std::is_same_v<T, mppp::real128>) {
1356 return fptr->taylor_c_diff_func_f128(s, n_uvars, batch_size, high_accuracy);
1357 #endif
1358 } else {
1359 static_assert(detail::always_false_v<T>, "Unhandled type.");
1360 }
1361 } else {
1362 // LCOV_EXCL_START
1363 throw std::invalid_argument("Taylor derivatives in compact mode can be computed only for functions");
1364 // LCOV_EXCL_STOP
1365 }
1366 }
1367
1368 } // namespace
1369
1370 } // namespace detail
1371
taylor_c_diff_func_dbl(llvm_state & s,const expression & ex,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy)1372 llvm::Function *taylor_c_diff_func_dbl(llvm_state &s, const expression &ex, std::uint32_t n_uvars,
1373 std::uint32_t batch_size, bool high_accuracy)
1374 {
1375 return detail::taylor_c_diff_func_impl<double>(s, ex, n_uvars, batch_size, high_accuracy);
1376 }
1377
taylor_c_diff_func_ldbl(llvm_state & s,const expression & ex,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy)1378 llvm::Function *taylor_c_diff_func_ldbl(llvm_state &s, const expression &ex, std::uint32_t n_uvars,
1379 std::uint32_t batch_size, bool high_accuracy)
1380 {
1381 return detail::taylor_c_diff_func_impl<long double>(s, ex, n_uvars, batch_size, high_accuracy);
1382 }
1383
1384 #if defined(HEYOKA_HAVE_REAL128)
1385
taylor_c_diff_func_f128(llvm_state & s,const expression & ex,std::uint32_t n_uvars,std::uint32_t batch_size,bool high_accuracy)1386 llvm::Function *taylor_c_diff_func_f128(llvm_state &s, const expression &ex, std::uint32_t n_uvars,
1387 std::uint32_t batch_size, bool high_accuracy)
1388 {
1389 return detail::taylor_c_diff_func_impl<mppp::real128>(s, ex, n_uvars, batch_size, high_accuracy);
1390 }
1391
1392 #endif
1393
1394 namespace detail
1395 {
1396
1397 // Helper to detect if ex is an integral number.
is_integral(const expression & ex)1398 bool is_integral(const expression &ex)
1399 {
1400 return std::visit(
1401 [](const auto &v) {
1402 using type = detail::uncvref_t<decltype(v)>;
1403
1404 if constexpr (std::is_same_v<type, number>) {
1405 return std::visit(
1406 [](const auto &x) {
1407 using std::trunc;
1408 using std::isfinite;
1409
1410 return isfinite(x) && x == trunc(x);
1411 },
1412 v.value());
1413 } else {
1414 // Not a number.
1415 return false;
1416 }
1417 },
1418 ex.value());
1419 }
1420
1421 // Helper to detect if ex is a number in the form n / 2,
1422 // where n is an odd integral value.
is_odd_integral_half(const expression & ex)1423 bool is_odd_integral_half(const expression &ex)
1424 {
1425 return std::visit(
1426 [](const auto &v) {
1427 using type = detail::uncvref_t<decltype(v)>;
1428
1429 if constexpr (std::is_same_v<type, number>) {
1430 return std::visit(
1431 [](const auto &x) {
1432 using std::trunc;
1433 using std::isfinite;
1434
1435 if (!isfinite(x) || x == trunc(x)) {
1436 // x is not finite, or it is already
1437 // an integral value.
1438 return false;
1439 }
1440
1441 const auto y = 2 * x;
1442
1443 return isfinite(y) && y == trunc(y);
1444 },
1445 v.value());
1446 } else {
1447 // Not a number.
1448 return false;
1449 }
1450 },
1451 ex.value());
1452 }
1453
operator [](std::uint32_t idx) const1454 expression par_impl::operator[](std::uint32_t idx) const
1455 {
1456 return expression{param{idx}};
1457 }
1458
1459 } // namespace detail
1460
1461 namespace detail
1462 {
1463
1464 namespace
1465 {
1466
get_param_size(std::unordered_set<const void * > & func_set,const expression & ex)1467 std::uint32_t get_param_size(std::unordered_set<const void *> &func_set, const expression &ex)
1468 {
1469 std::uint32_t retval = 0;
1470
1471 std::visit(
1472 [&retval, &func_set](const auto &v) {
1473 using type = detail::uncvref_t<decltype(v)>;
1474
1475 if constexpr (std::is_same_v<type, param>) {
1476 if (v.idx() == std::numeric_limits<std::uint32_t>::max()) {
1477 throw std::overflow_error("Overflow dected in get_n_param()");
1478 }
1479
1480 retval = std::max(static_cast<std::uint32_t>(v.idx() + 1u), retval);
1481 } else if constexpr (std::is_same_v<type, func>) {
1482 const auto f_id = v.get_ptr();
1483
1484 if (auto it = func_set.find(f_id); it != func_set.end()) {
1485 // We already computed the number of params for the current
1486 // function, exit.
1487 return;
1488 }
1489
1490 for (const auto &a : v.args()) {
1491 retval = std::max(get_param_size(func_set, a), retval);
1492 }
1493
1494 // Update the cache.
1495 [[maybe_unused]] const auto [_, flag] = func_set.insert(f_id);
1496 // NOTE: an expression cannot contain itself.
1497 assert(flag);
1498 }
1499 },
1500 ex.value());
1501
1502 return retval;
1503 }
1504
1505 } // namespace
1506
1507 } // namespace detail
1508
1509 // Determine the size of the parameter vector from the highest
1510 // param index appearing in an expression. If the return value
1511 // is zero, no params appear in the expression.
get_param_size(const expression & ex)1512 std::uint32_t get_param_size(const expression &ex)
1513 {
1514 std::unordered_set<const void *> func_set;
1515
1516 return detail::get_param_size(func_set, ex);
1517 }
1518
1519 namespace detail
1520 {
1521
1522 namespace
1523 {
1524
has_time(std::unordered_set<const void * > & func_set,const expression & ex)1525 bool has_time(std::unordered_set<const void *> &func_set, const expression &ex)
1526 {
1527 // If the expression itself is a time function or a tpoly,
1528 // return true.
1529 if (detail::is_time(ex) || detail::is_tpoly(ex)) {
1530 return true;
1531 }
1532
1533 // Otherwise:
1534 // - if ex is a function, check if any of its arguments
1535 // is time-dependent,
1536 // - otherwise, return false.
1537 return std::visit(
1538 [&func_set](const auto &v) {
1539 using type = detail::uncvref_t<decltype(v)>;
1540
1541 if constexpr (std::is_same_v<type, func>) {
1542 const auto f_id = v.get_ptr();
1543
1544 if (auto it = func_set.find(f_id); it != func_set.end()) {
1545 // We already determined if this function contains time,
1546 // return false (if the function does contain time, the first
1547 // time it was encountered we returned true and we could not
1548 // possibly end up here).
1549 return false;
1550 }
1551
1552 // Update the cache.
1553 // NOTE: do it earlier than usual in order to avoid having
1554 // to repeat this code twice for the two paths below.
1555 func_set.insert(f_id);
1556
1557 for (const auto &a : v.args()) {
1558 if (has_time(func_set, a)) {
1559 return true;
1560 }
1561 }
1562 }
1563
1564 return false;
1565 },
1566 ex.value());
1567 }
1568
1569 } // namespace
1570
1571 } // namespace detail
1572
1573 // Determine if an expression is time-dependent.
has_time(const expression & ex)1574 bool has_time(const expression &ex)
1575 {
1576 std::unordered_set<const void *> func_set;
1577
1578 return detail::has_time(func_set, ex);
1579 }
1580
1581 } // namespace heyoka
1582