1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
21 
22 #include <heyoka/config.hpp>
23 #include <heyoka/detail/llvm_fwd.hpp>
24 #include <heyoka/exceptions.hpp>
25 #include <heyoka/expression.hpp>
26 #include <heyoka/func.hpp>
27 #include <heyoka/llvm_state.hpp>
28 #include <heyoka/s11n.hpp>
29 #include <heyoka/taylor.hpp>
30 
31 #include "catch.hpp"
32 
33 using namespace heyoka;
34 
35 struct func_00 : func_base {
func_00func_0036     func_00() : func_base("f", {}) {}
func_00func_0037     func_00(const std::string &name) : func_base(name, {}) {}
func_00func_0038     explicit func_00(std::vector<expression> args) : func_base("f", std::move(args)) {}
39 };
40 
// Empty type unrelated to func_base: func must not be constructible from it.
struct func_01 {};
43 
// Exercise a func wrapping func_00, which implements none of the optional
// member functions. The test checks the basic accessors, the input validation
// performed by the func wrapper, and the error messages raised by the default
// fallbacks.
TEST_CASE("func minimal")
{
    using Catch::Matchers::Message;

    // Basic accessors: concrete type, name and arguments.
    func f(func_00{{"x"_var, "y"_var}});
    REQUIRE(f.get_type_index() == typeid(func_00));
    REQUIRE(f.get_name() == "f");
    REQUIRE(f.args() == std::vector{"x"_var, "y"_var});

    // An empty function name is rejected at construction time.
    REQUIRE_THROWS_MATCHES(func{func_00{""}}, std::invalid_argument, Message("Cannot create a function with no name"));

    // Codegen. fake_val is a non-null placeholder obtained by reinterpreting the
    // address of an unrelated object. It only serves to get past the null-pointer
    // check; the call is expected to throw not_implemented_error.
    llvm_state s;
    auto fake_val = reinterpret_cast<llvm::Value *>(&s);
    REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {fake_val, fake_val}), not_implemented_error,
                           Message("double codegen is not implemented for the function 'f'"));
    // Null pointers in the array of values are rejected for every precision.
    REQUIRE_THROWS_MATCHES(
        f.codegen_dbl(s, {nullptr, nullptr}), std::invalid_argument,
        Message("Null pointer detected in the array of values passed to func::codegen_dbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(
        f.codegen_ldbl(s, {nullptr, nullptr}), std::invalid_argument,
        Message("Null pointer detected in the array of values passed to func::codegen_ldbl() for the function 'f'"));
#if defined(HEYOKA_HAVE_REAL128)
    REQUIRE_THROWS_MATCHES(
        f.codegen_f128(s, {nullptr, nullptr}), std::invalid_argument,
        Message("Null pointer detected in the array of values passed to func::codegen_f128() for the function 'f'"));
#endif
    // Differentiation: func_00 provides neither diff() nor gradient(), so both
    // the variable and the parameter overloads must throw.
    std::unordered_map<const void *, expression> func_map;
    REQUIRE_THROWS_MATCHES(f.diff(func_map, ""), not_implemented_error,
                           Message("Cannot compute the derivative of the function 'f' with respect to a variable, "
                                   "because the function does not provide "
                                   "neither a diff() nor a gradient() member function"));
    REQUIRE_THROWS_MATCHES(f.diff(func_map, std::get<param>(par[0].value())), not_implemented_error,
                           Message("Cannot compute the derivative of the function 'f' with respect to a parameter, "
                                   "because the function does not provide "
                                   "neither a diff() nor a gradient() member function"));
    // Evaluation fallbacks.
    REQUIRE_THROWS_MATCHES(f.eval_dbl({{}}, {}), not_implemented_error,
                           Message("double eval is not implemented for the function 'f'"));
    std::vector<double> tmp;
    REQUIRE_THROWS_MATCHES(f.eval_batch_dbl(tmp, {{}}, {}), not_implemented_error,
                           Message("double batch eval is not implemented for the function 'f'"));
    // Numerical evaluation: with the right number of arguments the fallback
    // throws; with the wrong number the arity check fires first.
    REQUIRE_THROWS_MATCHES(f.eval_num_dbl({1., 1.}), not_implemented_error,
                           Message("double numerical eval is not implemented for the function 'f'"));
    REQUIRE_THROWS_MATCHES(
        f.eval_num_dbl({}), std::invalid_argument,
        Message("Inconsistent number of arguments supplied to the double numerical evaluation of the function 'f': 2 "
                "arguments were expected, but 0 arguments were provided instead"));
    // Numerical evaluation of the derivative: fallback, arity check, and an
    // out-of-range derivative index (valid indices are 0 and 1 here).
    REQUIRE_THROWS_MATCHES(f.deval_num_dbl({1., 1.}, 0), not_implemented_error,
                           Message("double numerical eval of the derivative is not implemented for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.deval_num_dbl({1.}, 0), std::invalid_argument,
                           Message("Inconsistent number of arguments supplied to the double numerical evaluation of "
                                   "the derivative of function 'f': 2 "
                                   "arguments were expected, but 1 arguments were provided instead"));
    REQUIRE_THROWS_MATCHES(f.deval_num_dbl({1., 1.}, 2), std::invalid_argument,
                           Message("Invalid index supplied to the double numerical evaluation of the derivative of "
                                   "function 'f': index 2 was supplied, but the number of arguments is only 2"));

    // func must not be constructible from a type unrelated to func_base.
    REQUIRE(!std::is_constructible_v<func, func_01>);

    // The const and non-const get_ptr() overloads return the same pointer.
    auto orig_ptr = f.get_ptr();
    REQUIRE(orig_ptr == static_cast<const func &>(f).get_ptr());

    // Copies share the underlying implementation pointer.
    auto f2(f);
    REQUIRE(orig_ptr == f2.get_ptr());
    REQUIRE(f2.get_type_index() == typeid(func_00));
    REQUIRE(f2.get_name() == "f");
    REQUIRE(f2.args() == std::vector{"x"_var, "y"_var});

    // Move construction transfers the pointer.
    auto f3(std::move(f));
    REQUIRE(orig_ptr == f3.get_ptr());

    // Copy assignment into the moved-from f, then move assignment back.
    f = f3;
    REQUIRE(f.get_ptr() == f3.get_ptr());

    f = std::move(f3);
    REQUIRE(f.get_ptr() == orig_ptr);

    // Taylor diff. fake_ptr is a non-null placeholder for par_ptr/time_ptr.
    // Judging from the expected messages, the first integer argument is the
    // number of u variables and the fourth is the batch size. Each invalid input
    // is rejected with invalid_argument before the not-implemented fallback is
    // reached (last call of each group).
    auto a = 0;
    auto fake_ptr = reinterpret_cast<llvm::Value *>(&a);
    REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, nullptr, nullptr, 2, 2, 2, 0, false),
                           std::invalid_argument,
                           Message("Null par_ptr detected in func::taylor_diff_dbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, fake_ptr, nullptr, 2, 2, 2, 0, false),
                           std::invalid_argument,
                           Message("Null time_ptr detected in func::taylor_diff_dbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 2, 2, 0, false),
                           std::invalid_argument,
                           Message("Zero batch size detected in func::taylor_diff_dbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(
        f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 0, 2, 2, 1, false), std::invalid_argument,
        Message("Zero number of u variables detected in func::taylor_diff_dbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 1, 2, 1, false),
                           not_implemented_error,
                           Message("double Taylor diff is not implemented for the function 'f'"));

    REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, nullptr, nullptr, 2, 2, 2, 0, false),
                           std::invalid_argument,
                           Message("Null par_ptr detected in func::taylor_diff_ldbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, fake_ptr, nullptr, 2, 2, 2, 0, false),
                           std::invalid_argument,
                           Message("Null time_ptr detected in func::taylor_diff_ldbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 2, 2, 0, false),
                           std::invalid_argument,
                           Message("Zero batch size detected in func::taylor_diff_ldbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(
        f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 0, 2, 2, 1, false), std::invalid_argument,
        Message("Zero number of u variables detected in func::taylor_diff_ldbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 1, 2, 1, false),
                           not_implemented_error,
                           Message("long double Taylor diff is not implemented for the function 'f'"));

#if defined(HEYOKA_HAVE_REAL128)
    REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {nullptr, nullptr}, nullptr, nullptr, 2, 2, 2, 0, false),
                           std::invalid_argument,
                           Message("Null par_ptr detected in func::taylor_diff_f128() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {nullptr, nullptr}, fake_ptr, nullptr, 2, 2, 2, 0, false),
                           std::invalid_argument,
                           Message("Null time_ptr detected in func::taylor_diff_f128() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 2, 2, 0, false),
                           std::invalid_argument,
                           Message("Zero batch size detected in func::taylor_diff_f128() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(
        f.taylor_diff_f128(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 0, 2, 2, 1, false), std::invalid_argument,
        Message("Zero number of u variables detected in func::taylor_diff_f128() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 1, 2, 1, false),
                           not_implemented_error,
                           Message("float128 Taylor diff is not implemented for the function 'f'"));
#endif

    // Compact-mode Taylor diff. Judging from the messages, the integer arguments
    // are (number of u variables, batch size).
    REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_dbl(s, 2, 0, false), std::invalid_argument,
                           Message("Zero batch size detected in func::taylor_c_diff_func_dbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(
        f.taylor_c_diff_func_dbl(s, 0, 2, false), std::invalid_argument,
        Message("Zero number of u variables detected in func::taylor_c_diff_func_dbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_dbl(s, 2, 1, false), not_implemented_error,
                           Message("double Taylor diff in compact mode is not implemented for the function 'f'"));

    REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_ldbl(s, 2, 0, false), std::invalid_argument,
                           Message("Zero batch size detected in func::taylor_c_diff_func_ldbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(
        f.taylor_c_diff_func_ldbl(s, 0, 2, false), std::invalid_argument,
        Message("Zero number of u variables detected in func::taylor_c_diff_func_ldbl() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_ldbl(s, 2, 1, false), not_implemented_error,
                           Message("long double Taylor diff in compact mode is not implemented for the function 'f'"));

#if defined(HEYOKA_HAVE_REAL128)
    REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_f128(s, 2, 0, false), std::invalid_argument,
                           Message("Zero batch size detected in func::taylor_c_diff_func_f128() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(
        f.taylor_c_diff_func_f128(s, 0, 2, false), std::invalid_argument,
        Message("Zero number of u variables detected in func::taylor_c_diff_func_f128() for the function 'f'"));
    REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_f128(s, 2, 1, false), not_implemented_error,
                           Message("float128 Taylor diff in compact mode is not implemented for the function 'f'"));
#endif

    // Smoke test of the default Taylor decomposition. There is no explicit
    // assertion; the test only fails if the call throws.
    taylor_dc_t dec{{"x"_var, {}}};
    f = func{func_00{{"x"_var, "y"_var}}};
    std::unordered_map<const void *, taylor_dc_t::size_type> func_map2;
    f.taylor_decompose(func_map2, dec);
}
203 
204 struct func_02 : func_base {
func_02func_02205     func_02() : func_base("f", {}) {}
func_02func_02206     explicit func_02(std::vector<expression> args) : func_base("f", std::move(args)) {}
207 
codegen_dblfunc_02208     llvm::Value *codegen_dbl(llvm_state &, const std::vector<llvm::Value *> &) const
209     {
210         return nullptr;
211     }
212 };
213 
214 struct func_03 : func_base {
func_03func_03215     func_03() : func_base("f", {}) {}
func_03func_03216     explicit func_03(std::vector<expression> args) : func_base("f", std::move(args)) {}
217 
codegen_ldblfunc_03218     llvm::Value *codegen_ldbl(llvm_state &, const std::vector<llvm::Value *> &) const
219     {
220         return nullptr;
221     }
222 };
223 
#if defined(HEYOKA_HAVE_REAL128)

// Implements only the float128 codegen, which returns a null pointer.
struct func_04 : func_base {
    func_04() : func_04(std::vector<expression>{}) {}
    explicit func_04(std::vector<expression> args) : func_base("f", std::move(args)) {}

    auto codegen_f128(llvm_state &, const std::vector<llvm::Value *> &) const -> llvm::Value *
    {
        return nullptr;
    }
};

#endif
237 
238 TEST_CASE("func codegen")
239 {
240     using Catch::Matchers::Message;
241 
242     func f(func_02{{}});
243 
244     llvm_state s;
245     REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {}), std::invalid_argument,
246                            Message("The double codegen for the function 'f' returned a null pointer"));
247     REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {nullptr}), std::invalid_argument,
248                            Message("Inconsistent number of arguments supplied to the double codegen for the function "
249                                    "'f': 0 arguments were expected, but 1 arguments were provided instead"));
250     REQUIRE_THROWS_MATCHES(f.codegen_ldbl(s, {}), not_implemented_error,
251                            Message("long double codegen is not implemented for the function 'f'"));
252     REQUIRE_THROWS_MATCHES(
253         f.codegen_ldbl(s, {nullptr}), std::invalid_argument,
254         Message("Inconsistent number of arguments supplied to the long double codegen for the function "
255                 "'f': 0 arguments were expected, but 1 arguments were provided instead"));
256 #if defined(HEYOKA_HAVE_REAL128)
257     REQUIRE_THROWS_MATCHES(f.codegen_f128(s, {}), not_implemented_error,
258                            Message("float128 codegen is not implemented for the function 'f'"));
259     REQUIRE_THROWS_MATCHES(f.codegen_f128(s, {nullptr}), std::invalid_argument,
260                            Message("Inconsistent number of arguments supplied to the float128 codegen for the function "
261                                    "'f': 0 arguments were expected, but 1 arguments were provided instead"));
262 #endif
263 
264     f = func(func_03{{}});
265     REQUIRE_THROWS_MATCHES(f.codegen_ldbl(s, {}), std::invalid_argument,
266                            Message("The long double codegen for the function 'f' returned a null pointer"));
267     REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {}), not_implemented_error,
268                            Message("double codegen is not implemented for the function 'f'"));
269 #if defined(HEYOKA_HAVE_REAL128)
270     REQUIRE_THROWS_MATCHES(f.codegen_f128(s, {}), not_implemented_error,
271                            Message("float128 codegen is not implemented for the function 'f'"));
272 #endif
273 
274 #if defined(HEYOKA_HAVE_REAL128)
275     f = func(func_04{{}});
276     REQUIRE_THROWS_MATCHES(f.codegen_f128(s, {}), std::invalid_argument,
277                            Message("The float128 codegen for the function 'f' returned a null pointer"));
278     REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {}), not_implemented_error,
279                            Message("double codegen is not implemented for the function 'f'"));
280     REQUIRE_THROWS_MATCHES(f.codegen_ldbl(s, {}), not_implemented_error,
281                            Message("long double codegen is not implemented for the function 'f'"));
282 #endif
283 }
284 
285 struct func_05 : func_base {
func_05func_05286     func_05() : func_base("f", {}) {}
func_05func_05287     explicit func_05(std::vector<expression> args) : func_base("f", std::move(args)) {}
288 
difffunc_05289     expression diff(std::unordered_map<const void *, expression> &, const std::string &) const
290     {
291         return 42_dbl;
292     }
293 };
294 
295 struct func_05a : func_base {
func_05afunc_05a296     func_05a() : func_base("f", {}) {}
func_05afunc_05a297     explicit func_05a(std::vector<expression> args) : func_base("f", std::move(args)) {}
298 
gradientfunc_05a299     std::vector<expression> gradient() const
300     {
301         return {};
302     }
303 };
304 
305 struct func_05b : func_base {
func_05bfunc_05b306     func_05b() : func_base("f", {}) {}
func_05bfunc_05b307     explicit func_05b(std::vector<expression> args) : func_base("f", std::move(args)) {}
308 
difffunc_05b309     expression diff(std::unordered_map<const void *, expression> &, const param &) const
310     {
311         return -42_dbl;
312     }
313 };
314 
315 TEST_CASE("func diff")
316 {
317     using Catch::Matchers::Message;
318 
319     auto f = func(func_05{});
320 
321     std::unordered_map<const void *, expression> func_map;
322     REQUIRE(f.diff(func_map, "x") == 42_dbl);
323     REQUIRE_THROWS_MATCHES(func(func_05a{{"x"_var}}).diff(func_map, "x"), std::invalid_argument,
324                            Message("Inconsistent gradient returned by the function 'f': a vector of 1 elements was "
325                                    "expected, but the number of elements is 0 instead"));
326     REQUIRE(func(func_05b{{"x"_var}}).diff(func_map, std::get<param>(par[0].value())) == -42_dbl);
327 }
328 
329 struct func_06 : func_base {
func_06func_06330     func_06() : func_base("f", {}) {}
func_06func_06331     explicit func_06(std::vector<expression> args) : func_base("f", std::move(args)) {}
332 
eval_dblfunc_06333     double eval_dbl(const std::unordered_map<std::string, double> &, const std::vector<double> &) const
334     {
335         return 42;
336     }
eval_ldblfunc_06337     long double eval_ldbl(const std::unordered_map<std::string, long double> &, const std::vector<long double> &) const
338     {
339         return 42;
340     }
341 #if defined(HEYOKA_HAVE_REAL128)
eval_f128func_06342     mppp::real128 eval_f128(const std::unordered_map<std::string, mppp::real128> &,
343                             const std::vector<mppp::real128> &) const
344     {
345         return mppp::real128(42);
346     }
347 #endif
348 };
349 
350 TEST_CASE("func eval_dbl")
351 {
352     auto f = func(func_06{});
353 
354     REQUIRE(f.eval_dbl({{}}, {}) == 42);
355 }
356 
357 struct func_07 : func_base {
func_07func_07358     func_07() : func_base("f", {}) {}
func_07func_07359     explicit func_07(std::vector<expression> args) : func_base("f", std::move(args)) {}
360 
eval_batch_dblfunc_07361     void eval_batch_dbl(std::vector<double> &, const std::unordered_map<std::string, std::vector<double>> &,
362                         const std::vector<double> &) const
363     {
364     }
365 };
366 
367 TEST_CASE("func eval_batch_dbl")
368 {
369     auto f = func(func_07{});
370 
371     std::vector<double> tmp;
372     REQUIRE_NOTHROW(f.eval_batch_dbl(tmp, {{}}, {}));
373 }
374 
375 struct func_08 : func_base {
func_08func_08376     func_08() : func_base("f", {}) {}
func_08func_08377     explicit func_08(std::vector<expression> args) : func_base("f", std::move(args)) {}
378 
eval_num_dblfunc_08379     double eval_num_dbl(const std::vector<double> &) const
380     {
381         return 42;
382     }
383 };
384 
385 TEST_CASE("func eval_num_dbl")
386 {
387     auto f = func(func_08{{"x"_var}});
388 
389     REQUIRE(f.eval_num_dbl({1.}) == 42);
390 }
391 
392 struct func_09 : func_base {
func_09func_09393     func_09() : func_base("f", {}) {}
func_09func_09394     explicit func_09(std::vector<expression> args) : func_base("f", std::move(args)) {}
395 
deval_num_dblfunc_09396     double deval_num_dbl(const std::vector<double> &, std::vector<double>::size_type) const
397     {
398         return 43;
399     }
400 };
401 
402 TEST_CASE("func deval_num_dbl")
403 {
404     auto f = func(func_09{{"x"_var}});
405 
406     REQUIRE(f.deval_num_dbl({1.}, 0) == 43);
407 }
408 
409 struct func_10 : func_base {
func_10func_10410     func_10() : func_base("f", {}) {}
func_10func_10411     explicit func_10(std::vector<expression> args) : func_base("f", std::move(args)) {}
412 
taylor_decomposefunc_10413     taylor_dc_t::size_type taylor_decompose(taylor_dc_t &u_vars_defs) &&
414     {
415         u_vars_defs.emplace_back("foo", std::vector<std::uint32_t>{});
416 
417         return u_vars_defs.size() - 1u;
418     }
419 };
420 
421 struct func_10a : func_base {
func_10afunc_10a422     func_10a() : func_base("f", {}) {}
func_10afunc_10a423     explicit func_10a(std::vector<expression> args) : func_base("f", std::move(args)) {}
424 
taylor_decomposefunc_10a425     taylor_dc_t::size_type taylor_decompose(taylor_dc_t &u_vars_defs) &&
426     {
427         u_vars_defs.emplace_back("foo", std::vector<std::uint32_t>{});
428 
429         return u_vars_defs.size();
430     }
431 };
432 
433 struct func_10b : func_base {
func_10bfunc_10b434     func_10b() : func_base("f", {}) {}
func_10bfunc_10b435     explicit func_10b(std::vector<expression> args) : func_base("f", std::move(args)) {}
436 
taylor_decomposefunc_10b437     taylor_dc_t::size_type taylor_decompose(taylor_dc_t &u_vars_defs) &&
438     {
439         u_vars_defs.emplace_back("foo", std::vector<std::uint32_t>{});
440 
441         return 0;
442     }
443 };
444 
445 TEST_CASE("func taylor_decompose")
446 {
447     using Catch::Matchers::Message;
448 
449     auto f = func(func_10{{"x"_var}});
450 
451     taylor_dc_t u_vars_defs{{"x"_var, {}}};
452     std::unordered_map<const void *, taylor_dc_t::size_type> func_map;
453     REQUIRE(f.taylor_decompose(func_map, u_vars_defs) == 1u);
454     REQUIRE(u_vars_defs == taylor_dc_t{{"x"_var, {}}, {"foo"_var, {}}});
455 
456     func_map = {};
457 
458     f = func(func_10a{{"x"_var}});
459 
460     REQUIRE_THROWS_MATCHES(
461         f.taylor_decompose(func_map, u_vars_defs), std::invalid_argument,
462         Message("Invalid value returned by the Taylor decomposition function for the function 'f': "
463                 "the return value is 3, which is not less than the current size of the decomposition "
464                 "(3)"));
465 
466     f = func(func_10b{{"x"_var}});
467 
468     REQUIRE_THROWS_MATCHES(f.taylor_decompose(func_map, u_vars_defs), std::invalid_argument,
469                            Message("The return value for the Taylor decomposition of a function can never be zero"));
470 }
471 
472 struct func_12 : func_base {
func_12func_12473     func_12() : func_base("f", {}) {}
func_12func_12474     explicit func_12(std::vector<expression> args) : func_base("f", std::move(args)) {}
475 
taylor_diff_dblfunc_12476     llvm::Value *taylor_diff_dbl(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
477                                  llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
478                                  std::uint32_t, bool) const
479     {
480         return nullptr;
481     }
taylor_diff_ldblfunc_12482     llvm::Value *taylor_diff_ldbl(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
483                                   llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
484                                   std::uint32_t, bool) const
485     {
486         return nullptr;
487     }
488 #if defined(HEYOKA_HAVE_REAL128)
taylor_diff_f128func_12489     llvm::Value *taylor_diff_f128(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &,
490                                   llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
491                                   std::uint32_t, bool) const
492     {
493         return nullptr;
494     }
495 #endif
496 };
497 
498 TEST_CASE("func taylor diff")
499 {
500     using Catch::Matchers::Message;
501 
502     auto f = func(func_12{});
503 
504     llvm_state s;
505     auto a = 0;
506     auto fake_ptr = reinterpret_cast<llvm::Value *>(&a);
507     REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {}, fake_ptr, fake_ptr, 1, 2, 3, 4, false), std::invalid_argument,
508                            Message("Null return value detected in func::taylor_diff_dbl() for the function 'f'"));
509     REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {}, fake_ptr, fake_ptr, 1, 2, 3, 4, false), std::invalid_argument,
510                            Message("Null return value detected in func::taylor_diff_ldbl() for the function 'f'"));
511 #if defined(HEYOKA_HAVE_REAL128)
512     REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {}, fake_ptr, fake_ptr, 1, 2, 3, 4, false), std::invalid_argument,
513                            Message("Null return value detected in func::taylor_diff_f128() for the function 'f'"));
514 #endif
515 }
516 
517 struct func_13 : func_base {
func_13func_13518     func_13() : func_base("f", {}) {}
func_13func_13519     explicit func_13(std::vector<expression> args) : func_base("f", std::move(args)) {}
520 
taylor_c_diff_func_dblfunc_13521     llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const
522     {
523         return nullptr;
524     }
taylor_c_diff_func_ldblfunc_13525     llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const
526     {
527         return nullptr;
528     }
529 #if defined(HEYOKA_HAVE_REAL128)
taylor_c_diff_func_f128func_13530     llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const
531     {
532         return nullptr;
533     }
534 #endif
535 };
536 
537 TEST_CASE("func taylor c_diff")
538 {
539     using Catch::Matchers::Message;
540 
541     auto f = func(func_13{});
542 
543     llvm_state s;
544     REQUIRE_THROWS_MATCHES(
545         f.taylor_c_diff_func_dbl(s, 3, 4, false), std::invalid_argument,
546         Message("Null return value detected in func::taylor_c_diff_func_dbl() for the function 'f'"));
547     REQUIRE_THROWS_MATCHES(
548         f.taylor_c_diff_func_ldbl(s, 2, 3, false), std::invalid_argument,
549         Message("Null return value detected in func::taylor_c_diff_func_ldbl() for the function 'f'"));
550 #if defined(HEYOKA_HAVE_REAL128)
551     REQUIRE_THROWS_MATCHES(
552         f.taylor_c_diff_func_f128(s, 2, 4, false), std::invalid_argument,
553         Message("Null return value detected in func::taylor_c_diff_func_f128() for the function 'f'"));
554 #endif
555 }
556 
557 TEST_CASE("func swap")
558 {
559     using std::swap;
560 
561     auto f1 = func(func_10{{"x"_var}});
562     auto f2 = func(func_12{{"y"_var}});
563 
564     swap(f1, f2);
565 
566     REQUIRE(f1.get_type_index() == typeid(func_12));
567     REQUIRE(f2.get_type_index() == typeid(func_10));
568     REQUIRE(f1.args() == std::vector{"y"_var});
569     REQUIRE(f2.args() == std::vector{"x"_var});
570 
571     REQUIRE(std::is_nothrow_swappable_v<func>);
572 }
573 
574 TEST_CASE("func ostream")
575 {
576     auto f1 = func(func_10{{"x"_var, "y"_var}});
577 
578     std::ostringstream oss;
579     oss << f1;
580 
581     REQUIRE(oss.str() == "f(x, y)");
582 
583     oss.str("");
584 
585     f1 = func(func_10{{"y"_var}});
586 
587     oss << f1;
588 
589     REQUIRE(oss.str() == "f(y)");
590 }
591 
592 TEST_CASE("func hash")
593 {
594     auto f1 = func(func_10{{"x"_var, "y"_var}});
595 
596     REQUIRE_NOTHROW(hash(f1));
597 
598     std::cout << "Hash value for f1: " << hash(f1) << '\n';
599 }
600 
601 struct func_14 : func_base {
func_14func_14602     func_14(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args))
603     {
604     }
func_14func_14605     explicit func_14(std::vector<expression> args) : func_base("f", std::move(args)) {}
606 };
607 
608 TEST_CASE("func eq ineq")
609 {
610     auto f1 = func(func_10{{"x"_var, "y"_var}});
611 
612     REQUIRE(f1 == f1);
613     REQUIRE(!(f1 != f1));
614     REQUIRE(hash(f1) == hash(f1));
615 
616     // Differing arguments.
617     auto f2 = func(func_10{{"y"_var, "x"_var}});
618 
619     REQUIRE(f1 != f2);
620     REQUIRE(!(f1 == f2));
621 
622     auto f3 = func(func_14{{"x"_var, "y"_var}});
623     auto f4 = func(func_14{"g", {"x"_var, "y"_var}});
624 
625     // Differing names.
626     REQUIRE(f3 != f4);
627     REQUIRE(!(f3 == f4));
628 
629     // Differing underlying types.
630     f3 = func(func_10{{"x"_var, "y"_var}});
631     f4 = func(func_14{{"x"_var, "y"_var}});
632 
633     REQUIRE(f3 != f4);
634     REQUIRE(!(f3 == f4));
635 }
636 
637 TEST_CASE("func get_variables")
638 {
639     auto f1 = func(func_10{{}});
640     REQUIRE(get_variables(expression{f1}).empty());
641 
642     f1 = func(func_10{{0_dbl}});
643     REQUIRE(get_variables(expression{f1}).empty());
644 
645     f1 = func(func_10{{0_dbl, "x"_var}});
646     REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"x"});
647 
648     f1 = func(func_10{{0_dbl, "y"_var, "x"_var}});
649     REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"x", "y"});
650     f1 = func(func_10{{0_dbl, "y"_var, "x"_var, 1_dbl, "x"_var, "y"_var, "z"_var}});
651     REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"x", "y", "z"});
652 }
653 
654 TEST_CASE("func rename_variables")
655 {
656     auto f1 = expression{func(func_10{{}})};
657     auto f2 = f1;
658     rename_variables(f1, {{}});
659     REQUIRE(f2 == f1);
660 
661     f1 = expression{func(func_10{{0_dbl, "x"_var}})};
662     f2 = f1;
663     rename_variables(f1, {{}});
664     REQUIRE(f2 == f1);
665 
666     rename_variables(f1, {{"x", "y"}});
667     REQUIRE(f2 == f1);
668     REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"y"});
669     rename_variables(f1, {{"x", "y"}});
670     REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"y"});
671 
672     f1 = expression{func(func_10{{"x"_var, 0_dbl, "z"_var, "y"_var}})};
673     rename_variables(f1, {{"x", "y"}});
674     REQUIRE(f2 != f1);
675     REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"y", "z"});
676 }
677 
678 TEST_CASE("func diff free func")
679 {
680     using Catch::Matchers::Message;
681 
682     auto f1 = func(func_05{{}});
683 
684     REQUIRE(diff(expression{f1}, "x") == 42_dbl);
685 
686     f1 = func(func_00{});
687     REQUIRE_THROWS_MATCHES(diff(expression{f1}, ""), not_implemented_error,
688                            Message("Cannot compute the derivative of the function 'f' with respect to a variable, "
689                                    "because the function does not provide "
690                                    "neither a diff() nor a gradient() member function"));
691     REQUIRE_THROWS_MATCHES(diff(expression{f1}, par[0]), not_implemented_error,
692                            Message("Cannot compute the derivative of the function 'f' with respect to a parameter, "
693                                    "because the function does not provide "
694                                    "neither a diff() nor a gradient() member function"));
695 }
696 
697 struct func_15 : func_base {
func_15func_15698     func_15(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args))
699     {
700     }
func_15func_15701     explicit func_15(std::vector<expression> args) : func_base("f", std::move(args)) {}
702 };
703 
704 TEST_CASE("func subs")
705 {
706     auto f1 = func(func_15{{"x"_var, "y"_var}});
707 
708     auto f2 = subs(expression{f1}, {{}});
709     REQUIRE(f2 == expression{f1});
710 
711     f2 = subs(expression{f1}, {{"x", "z"_var}});
712     REQUIRE(f2 == expression{func(func_15{{"z"_var, "y"_var}})});
713 
714     f2 = subs(expression{f1}, {{"x", "z"_var}, {"y", 42_dbl}});
715     REQUIRE(f2 == expression{func(func_15{{"z"_var, 42_dbl}})});
716 }
717 
718 struct func_16 : func_base {
func_16func_16719     func_16(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args))
720     {
721     }
func_16func_16722     explicit func_16(std::vector<expression> args) : func_base("f", std::move(args)) {}
723 
to_streamfunc_16724     void to_stream(std::ostream &os) const
725     {
726         os << "Custom to stream";
727     }
728 };
729 
730 TEST_CASE("func to_stream")
731 {
732     auto f1 = func(func_15{{"x"_var, "y"_var}});
733 
734     std::cout << "Default stream: " << f1 << '\n';
735 
736     auto f2 = func(func_16{{"x"_var, "y"_var}});
737 
738     std::ostringstream oss;
739     oss << f2;
740     REQUIRE(oss.str() == "Custom to stream");
741 }
742 
743 TEST_CASE("func extract")
744 {
745     auto f1 = func(func_15{{"x"_var, "y"_var}});
746 
747     REQUIRE(f1.extract<func_15>() != nullptr);
748     REQUIRE(static_cast<const func &>(f1).extract<func_15>() != nullptr);
749 
750     REQUIRE(f1.extract<func_16>() == nullptr);
751     REQUIRE(static_cast<const func &>(f1).extract<func_16>() == nullptr);
752 
753 #if !defined(_MSC_VER) || defined(__clang__)
754     // NOTE: vanilla MSVC does not like these extraction.
755     REQUIRE(f1.extract<const func_15>() == nullptr);
756     REQUIRE(static_cast<const func &>(f1).extract<const func_15>() == nullptr);
757 
758     REQUIRE(f1.extract<int>() == nullptr);
759     REQUIRE(static_cast<const func &>(f1).extract<int>() == nullptr);
760 
761 #endif
762 }
763 
764 struct func_17 : func_base {
func_17func_17765     func_17(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args))
766     {
767     }
func_17func_17768     explicit func_17(int n, std::vector<expression> args) : func_base("f", std::move(args)), value(n) {}
769 
extra_equal_tofunc_17770     bool extra_equal_to(const func &f) const
771     {
772         return f.extract<func_17>()->value == value;
773     }
774 
775     int value = 0;
776 };
777 
778 TEST_CASE("func extra_equal_to")
779 {
780     auto f1 = func(func_17{0, {"x"_var, "y"_var}});
781     auto f2 = func(func_17{0, {"x"_var, "y"_var}});
782     auto f3 = func(func_17{1, {"x"_var, "y"_var}});
783 
784     REQUIRE(f1 == f2);
785     REQUIRE(f1 != f3);
786 }
787 
788 struct func_18 : func_base {
func_18func_18789     func_18(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args))
790     {
791     }
func_18func_18792     explicit func_18(int n, std::vector<expression> args) : func_base("f", std::move(args)), value(n) {}
793 
extra_hashfunc_18794     std::size_t extra_hash() const
795     {
796         return static_cast<std::size_t>(value);
797     }
798 
799     int value = 0;
800 };
801 
802 TEST_CASE("func extra_hash")
803 {
804     auto f1 = func(func_18{0, {"x"_var, "y"_var}});
805     auto f2 = func(func_18{0, {"x"_var, "y"_var}});
806     auto f3 = func(func_18{-1, {"x"_var, "y"_var}});
807 
808     REQUIRE(hash(f1) == hash(f2));
809     REQUIRE(hash(f1) != hash(f3));
810 }
811 
812 struct func_19 : func_base {
func_19func_19813     func_19(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args))
814     {
815     }
816 
817 private:
818     friend class boost::serialization::access;
819     template <typename Archive>
serializefunc_19820     void serialize(Archive &ar, unsigned)
821     {
822         ar &boost::serialization::base_object<func_base>(*this);
823     }
824 };
825 
826 HEYOKA_S11N_FUNC_EXPORT(func_19)
827 
828 TEST_CASE("func s11n")
829 {
830     std::stringstream ss;
831 
832     func f{func_19{"pluto", {"x"_var}}};
833 
834     {
835         boost::archive::binary_oarchive oa(ss);
836 
837         oa << f;
838     }
839 
840     f = func{};
841 
842     {
843         boost::archive::binary_iarchive ia(ss);
844 
845         ia >> f;
846     }
847 
848     REQUIRE(f.get_name() == "pluto");
849     REQUIRE(f.args().size() == 1u);
850     REQUIRE(f.args()[0] == "x"_var);
851 }
852 
853 TEST_CASE("ref semantics")
854 {
855     auto [x, y, z] = make_vars("x", "y", "z");
856 
857     auto foo = (x + y) * z, bar = foo;
858 
859     REQUIRE(std::get<func>(foo.value()).get_ptr() == std::get<func>(bar.value()).get_ptr());
860 
861     foo = x - y;
862     bar = foo;
863 
864     REQUIRE(std::get<func>(foo.value()).get_ptr() == std::get<func>(bar.value()).get_ptr());
865 }
866 
867 TEST_CASE("copy")
868 {
869     auto [x, y, z] = make_vars("x", "y", "z");
870 
871     auto foo = ((x + y) * (z + x)) * ((z - x) * (y + x));
872 
873     auto foo_copy = expression{std::get<func>(foo.value()).copy()};
874 
875     // Copy creates a new obejct...
876     REQUIRE(std::get<func>(foo_copy.value()).get_ptr() != std::get<func>(foo.value()).get_ptr());
877 
878     // ... but it does not deep copy the arguments.
879     REQUIRE(std::get<func>(std::get<func>(foo_copy.value()).args()[0].value()).get_ptr()
880             == std::get<func>(std::get<func>(foo.value()).args()[0].value()).get_ptr());
881     REQUIRE(std::get<func>(std::get<func>(foo_copy.value()).args()[1].value()).get_ptr()
882             == std::get<func>(std::get<func>(foo.value()).args()[1].value()).get_ptr());
883 
884     REQUIRE(
885         std::get<func>(std::get<func>(std::get<func>(foo_copy.value()).args()[0].value()).args()[0].value()).get_ptr()
886         == std::get<func>(std::get<func>(std::get<func>(foo.value()).args()[0].value()).args()[0].value()).get_ptr());
887     REQUIRE(
888         std::get<func>(std::get<func>(std::get<func>(foo_copy.value()).args()[0].value()).args()[1].value()).get_ptr()
889         == std::get<func>(std::get<func>(std::get<func>(foo.value()).args()[0].value()).args()[1].value()).get_ptr());
890 
891     REQUIRE(
892         std::get<func>(std::get<func>(std::get<func>(foo_copy.value()).args()[1].value()).args()[0].value()).get_ptr()
893         == std::get<func>(std::get<func>(std::get<func>(foo.value()).args()[1].value()).args()[0].value()).get_ptr());
894     REQUIRE(
895         std::get<func>(std::get<func>(std::get<func>(foo_copy.value()).args()[1].value()).args()[1].value()).get_ptr()
896         == std::get<func>(std::get<func>(std::get<func>(foo.value()).args()[1].value()).args()[1].value()).get_ptr());
897 }
898