1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) 2 // 3 // This file is part of the heyoka library. 4 // 5 // This Source Code Form is subject to the terms of the Mozilla 6 // Public License v. 2.0. If a copy of the MPL was not distributed 7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 8 9 #include <cstddef> 10 #include <initializer_list> 11 #include <iostream> 12 #include <sstream> 13 #include <stdexcept> 14 #include <string> 15 #include <type_traits> 16 #include <typeinfo> 17 #include <unordered_map> 18 #include <utility> 19 #include <variant> 20 #include <vector> 21 22 #include <heyoka/config.hpp> 23 #include <heyoka/detail/llvm_fwd.hpp> 24 #include <heyoka/exceptions.hpp> 25 #include <heyoka/expression.hpp> 26 #include <heyoka/func.hpp> 27 #include <heyoka/llvm_state.hpp> 28 #include <heyoka/s11n.hpp> 29 #include <heyoka/taylor.hpp> 30 31 #include "catch.hpp" 32 33 using namespace heyoka; 34 35 struct func_00 : func_base { func_00func_0036 func_00() : func_base("f", {}) {} func_00func_0037 func_00(const std::string &name) : func_base(name, {}) {} func_00func_0038 explicit func_00(std::vector<expression> args) : func_base("f", std::move(args)) {} 39 }; 40 41 struct func_01 { 42 }; 43 44 TEST_CASE("func minimal") 45 { 46 using Catch::Matchers::Message; 47 48 func f(func_00{{"x"_var, "y"_var}}); 49 REQUIRE(f.get_type_index() == typeid(func_00)); 50 REQUIRE(f.get_name() == "f"); 51 REQUIRE(f.args() == std::vector{"x"_var, "y"_var}); 52 53 REQUIRE_THROWS_MATCHES(func{func_00{""}}, std::invalid_argument, Message("Cannot create a function with no name")); 54 55 llvm_state s; 56 auto fake_val = reinterpret_cast<llvm::Value *>(&s); 57 REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {fake_val, fake_val}), not_implemented_error, 58 Message("double codegen is not implemented for the function 'f'")); 59 REQUIRE_THROWS_MATCHES( 60 f.codegen_dbl(s, {nullptr, nullptr}), std::invalid_argument, 61 Message("Null pointer detected in the array of values passed to func::codegen_dbl() for the function 'f'")); 62 REQUIRE_THROWS_MATCHES( 63 f.codegen_ldbl(s, {nullptr, nullptr}), std::invalid_argument, 64 Message("Null pointer detected in the array of values passed to func::codegen_ldbl() for the function 'f'")); 65 #if defined(HEYOKA_HAVE_REAL128) 66 REQUIRE_THROWS_MATCHES( 67 f.codegen_f128(s, {nullptr, nullptr}), std::invalid_argument, 68 Message("Null pointer detected in the array of values passed to func::codegen_f128() for the function 'f'")); 69 #endif 70 std::unordered_map<const void *, expression> func_map; 71 REQUIRE_THROWS_MATCHES(f.diff(func_map, ""), not_implemented_error, 72 Message("Cannot compute the derivative of the function 'f' with respect to a variable, " 73 "because the function does not provide " 74 "neither a diff() nor a gradient() member function")); 75 REQUIRE_THROWS_MATCHES(f.diff(func_map, std::get<param>(par[0].value())), not_implemented_error, 76 Message("Cannot compute the derivative of the function 'f' with respect to a parameter, " 77 "because the function does not provide " 78 "neither a diff() nor a gradient() member function")); 79 REQUIRE_THROWS_MATCHES(f.eval_dbl({{}}, {}), not_implemented_error, 80 Message("double eval is not implemented for the function 'f'")); 81 std::vector<double> tmp; 82 REQUIRE_THROWS_MATCHES(f.eval_batch_dbl(tmp, {{}}, {}), not_implemented_error, 83 Message("double batch eval is not implemented for the function 'f'")); 84 REQUIRE_THROWS_MATCHES(f.eval_num_dbl({1., 1.}), not_implemented_error, 85 Message("double numerical eval is not implemented for the function 'f'")); 86 REQUIRE_THROWS_MATCHES( 87 f.eval_num_dbl({}), std::invalid_argument, 88 Message("Inconsistent number of arguments supplied to the double numerical evaluation of the function 'f': 2 " 89 "arguments were expected, but 0 arguments were provided instead")); 90 REQUIRE_THROWS_MATCHES(f.deval_num_dbl({1., 1.}, 0), not_implemented_error, 91 Message("double numerical eval of the derivative is not implemented for the function 'f'")); 92 REQUIRE_THROWS_MATCHES(f.deval_num_dbl({1.}, 0), std::invalid_argument, 93 Message("Inconsistent number of arguments supplied to the double numerical evaluation of " 94 "the derivative of function 'f': 2 " 95 "arguments were expected, but 1 arguments were provided instead")); 96 REQUIRE_THROWS_MATCHES(f.deval_num_dbl({1., 1.}, 2), std::invalid_argument, 97 Message("Invalid index supplied to the double numerical evaluation of the derivative of " 98 "function 'f': index 2 was supplied, but the number of arguments is only 2")); 99 100 REQUIRE(!std::is_constructible_v<func, func_01>); 101 102 auto orig_ptr = f.get_ptr(); 103 REQUIRE(orig_ptr == static_cast<const func &>(f).get_ptr()); 104 105 auto f2(f); 106 REQUIRE(orig_ptr == f2.get_ptr()); 107 REQUIRE(f2.get_type_index() == typeid(func_00)); 108 REQUIRE(f2.get_name() == "f"); 109 REQUIRE(f2.args() == std::vector{"x"_var, "y"_var}); 110 111 auto f3(std::move(f)); 112 REQUIRE(orig_ptr == f3.get_ptr()); 113 114 f = f3; 115 REQUIRE(f.get_ptr() == f3.get_ptr()); 116 117 f = std::move(f3); 118 REQUIRE(f.get_ptr() == orig_ptr); 119 120 auto a = 0; 121 auto fake_ptr = reinterpret_cast<llvm::Value *>(&a); 122 REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, nullptr, nullptr, 2, 2, 2, 0, false), 123 std::invalid_argument, 124 Message("Null par_ptr detected in func::taylor_diff_dbl() for the function 'f'")); 125 REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, fake_ptr, nullptr, 2, 2, 2, 0, false), 126 std::invalid_argument, 127 Message("Null time_ptr detected in func::taylor_diff_dbl() for the function 'f'")); 128 REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 2, 2, 0, false), 129 std::invalid_argument, 130 Message("Zero batch size detected in func::taylor_diff_dbl() for the function 'f'")); 131 REQUIRE_THROWS_MATCHES( 132 f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 0, 2, 2, 1, false), std::invalid_argument, 133 Message("Zero number of u variables detected in func::taylor_diff_dbl() for the function 'f'")); 134 REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 1, 2, 1, false), 135 not_implemented_error, 136 Message("double Taylor diff is not implemented for the function 'f'")); 137 138 REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, nullptr, nullptr, 2, 2, 2, 0, false), 139 std::invalid_argument, 140 Message("Null par_ptr detected in func::taylor_diff_ldbl() for the function 'f'")); 141 REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, fake_ptr, nullptr, 2, 2, 2, 0, false), 142 std::invalid_argument, 143 Message("Null time_ptr detected in func::taylor_diff_ldbl() for the function 'f'")); 144 REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 2, 2, 0, false), 145 std::invalid_argument, 146 Message("Zero batch size detected in func::taylor_diff_ldbl() for the function 'f'")); 147 REQUIRE_THROWS_MATCHES( 148 f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 0, 2, 2, 1, false), std::invalid_argument, 149 Message("Zero number of u variables detected in func::taylor_diff_ldbl() for the function 'f'")); 150 REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 1, 2, 1, false), 151 not_implemented_error, 152 Message("long double Taylor diff is not implemented for the function 'f'")); 153 154 #if defined(HEYOKA_HAVE_REAL128) 155 REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {nullptr, nullptr}, nullptr, nullptr, 2, 2, 2, 0, false), 156 std::invalid_argument, 157 Message("Null par_ptr detected in func::taylor_diff_f128() for the function 'f'")); 158 REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {nullptr, nullptr}, fake_ptr, nullptr, 2, 2, 2, 0, false), 159 std::invalid_argument, 160 Message("Null time_ptr detected in func::taylor_diff_f128() for the function 'f'")); 161 REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 2, 2, 0, false), 162 std::invalid_argument, 163 Message("Zero batch size detected in func::taylor_diff_f128() for the function 'f'")); 164 REQUIRE_THROWS_MATCHES( 165 f.taylor_diff_f128(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 0, 2, 2, 1, false), std::invalid_argument, 166 Message("Zero number of u variables detected in func::taylor_diff_f128() for the function 'f'")); 167 REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {nullptr, nullptr}, fake_ptr, fake_ptr, 2, 1, 2, 1, false), 168 not_implemented_error, 169 Message("float128 Taylor diff is not implemented for the function 'f'")); 170 #endif 171 172 REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_dbl(s, 2, 0, false), std::invalid_argument, 173 Message("Zero batch size detected in func::taylor_c_diff_func_dbl() for the function 'f'")); 174 REQUIRE_THROWS_MATCHES( 175 f.taylor_c_diff_func_dbl(s, 0, 2, false), std::invalid_argument, 176 Message("Zero number of u variables detected in func::taylor_c_diff_func_dbl() for the function 'f'")); 177 REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_dbl(s, 2, 1, false), not_implemented_error, 178 Message("double Taylor diff in compact mode is not implemented for the function 'f'")); 179 180 REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_ldbl(s, 2, 0, false), std::invalid_argument, 181 Message("Zero batch size detected in func::taylor_c_diff_func_ldbl() for the function 'f'")); 182 REQUIRE_THROWS_MATCHES( 183 f.taylor_c_diff_func_ldbl(s, 0, 2, false), std::invalid_argument, 184 Message("Zero number of u variables detected in func::taylor_c_diff_func_ldbl() for the function 'f'")); 185 REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_ldbl(s, 2, 1, false), not_implemented_error, 186 Message("long double Taylor diff in compact mode is not implemented for the function 'f'")); 187 188 #if defined(HEYOKA_HAVE_REAL128) 189 REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_f128(s, 2, 0, false), std::invalid_argument, 190 Message("Zero batch size detected in func::taylor_c_diff_func_f128() for the function 'f'")); 191 REQUIRE_THROWS_MATCHES( 192 f.taylor_c_diff_func_f128(s, 0, 2, false), std::invalid_argument, 193 Message("Zero number of u variables detected in func::taylor_c_diff_func_f128() for the function 'f'")); 194 REQUIRE_THROWS_MATCHES(f.taylor_c_diff_func_f128(s, 2, 1, false), not_implemented_error, 195 Message("float128 Taylor diff in compact mode is not implemented for the function 'f'")); 196 #endif 197 198 taylor_dc_t dec{{"x"_var, {}}}; 199 f = func{func_00{{"x"_var, "y"_var}}}; 200 std::unordered_map<const void *, taylor_dc_t::size_type> func_map2; 201 f.taylor_decompose(func_map2, dec); 202 } 203 204 struct func_02 : func_base { func_02func_02205 func_02() : func_base("f", {}) {} func_02func_02206 explicit func_02(std::vector<expression> args) : func_base("f", std::move(args)) {} 207 codegen_dblfunc_02208 llvm::Value *codegen_dbl(llvm_state &, const std::vector<llvm::Value *> &) const 209 { 210 return nullptr; 211 } 212 }; 213 214 struct func_03 : func_base { func_03func_03215 func_03() : func_base("f", {}) {} func_03func_03216 explicit func_03(std::vector<expression> args) : func_base("f", std::move(args)) {} 217 codegen_ldblfunc_03218 llvm::Value *codegen_ldbl(llvm_state &, const std::vector<llvm::Value *> &) const 219 { 220 return nullptr; 221 } 222 }; 223 224 #if defined(HEYOKA_HAVE_REAL128) 225 226 struct func_04 : func_base { func_04func_04227 func_04() : func_base("f", {}) {} func_04func_04228 explicit func_04(std::vector<expression> args) : func_base("f", std::move(args)) {} 229 codegen_f128func_04230 llvm::Value *codegen_f128(llvm_state &, const std::vector<llvm::Value *> &) const 231 { 232 return nullptr; 233 } 234 }; 235 236 #endif 237 238 TEST_CASE("func codegen") 239 { 240 using Catch::Matchers::Message; 241 242 func f(func_02{{}}); 243 244 llvm_state s; 245 REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {}), std::invalid_argument, 246 Message("The double codegen for the function 'f' returned a null pointer")); 247 REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {nullptr}), std::invalid_argument, 248 Message("Inconsistent number of arguments supplied to the double codegen for the function " 249 "'f': 0 arguments were expected, but 1 arguments were provided instead")); 250 REQUIRE_THROWS_MATCHES(f.codegen_ldbl(s, {}), not_implemented_error, 251 Message("long double codegen is not implemented for the function 'f'")); 252 REQUIRE_THROWS_MATCHES( 253 f.codegen_ldbl(s, {nullptr}), std::invalid_argument, 254 Message("Inconsistent number of arguments supplied to the long double codegen for the function " 255 "'f': 0 arguments were expected, but 1 arguments were provided instead")); 256 #if defined(HEYOKA_HAVE_REAL128) 257 REQUIRE_THROWS_MATCHES(f.codegen_f128(s, {}), not_implemented_error, 258 Message("float128 codegen is not implemented for the function 'f'")); 259 REQUIRE_THROWS_MATCHES(f.codegen_f128(s, {nullptr}), std::invalid_argument, 260 Message("Inconsistent number of arguments supplied to the float128 codegen for the function " 261 "'f': 0 arguments were expected, but 1 arguments were provided instead")); 262 #endif 263 264 f = func(func_03{{}}); 265 REQUIRE_THROWS_MATCHES(f.codegen_ldbl(s, {}), std::invalid_argument, 266 Message("The long double codegen for the function 'f' returned a null pointer")); 267 REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {}), not_implemented_error, 268 Message("double codegen is not implemented for the function 'f'")); 269 #if defined(HEYOKA_HAVE_REAL128) 270 REQUIRE_THROWS_MATCHES(f.codegen_f128(s, {}), not_implemented_error, 271 Message("float128 codegen is not implemented for the function 'f'")); 272 #endif 273 274 #if defined(HEYOKA_HAVE_REAL128) 275 f = func(func_04{{}}); 276 REQUIRE_THROWS_MATCHES(f.codegen_f128(s, {}), std::invalid_argument, 277 Message("The float128 codegen for the function 'f' returned a null pointer")); 278 REQUIRE_THROWS_MATCHES(f.codegen_dbl(s, {}), not_implemented_error, 279 Message("double codegen is not implemented for the function 'f'")); 280 REQUIRE_THROWS_MATCHES(f.codegen_ldbl(s, {}), not_implemented_error, 281 Message("long double codegen is not implemented for the function 'f'")); 282 #endif 283 } 284 285 struct func_05 : func_base { func_05func_05286 func_05() : func_base("f", {}) {} func_05func_05287 explicit func_05(std::vector<expression> args) : func_base("f", std::move(args)) {} 288 difffunc_05289 expression diff(std::unordered_map<const void *, expression> &, const std::string &) const 290 { 291 return 42_dbl; 292 } 293 }; 294 295 struct func_05a : func_base { func_05afunc_05a296 func_05a() : func_base("f", {}) {} func_05afunc_05a297 explicit func_05a(std::vector<expression> args) : func_base("f", std::move(args)) {} 298 gradientfunc_05a299 std::vector<expression> gradient() const 300 { 301 return {}; 302 } 303 }; 304 305 struct func_05b : func_base { func_05bfunc_05b306 func_05b() : func_base("f", {}) {} func_05bfunc_05b307 explicit func_05b(std::vector<expression> args) : func_base("f", std::move(args)) {} 308 difffunc_05b309 expression diff(std::unordered_map<const void *, expression> &, const param &) const 310 { 311 return -42_dbl; 312 } 313 }; 314 315 TEST_CASE("func diff") 316 { 317 using Catch::Matchers::Message; 318 319 auto f = func(func_05{}); 320 321 std::unordered_map<const void *, expression> func_map; 322 REQUIRE(f.diff(func_map, "x") == 42_dbl); 323 REQUIRE_THROWS_MATCHES(func(func_05a{{"x"_var}}).diff(func_map, "x"), std::invalid_argument, 324 Message("Inconsistent gradient returned by the function 'f': a vector of 1 elements was " 325 "expected, but the number of elements is 0 instead")); 326 REQUIRE(func(func_05b{{"x"_var}}).diff(func_map, std::get<param>(par[0].value())) == -42_dbl); 327 } 328 329 struct func_06 : func_base { func_06func_06330 func_06() : func_base("f", {}) {} func_06func_06331 explicit func_06(std::vector<expression> args) : func_base("f", std::move(args)) {} 332 eval_dblfunc_06333 double eval_dbl(const std::unordered_map<std::string, double> &, const std::vector<double> &) const 334 { 335 return 42; 336 } eval_ldblfunc_06337 long double eval_ldbl(const std::unordered_map<std::string, long double> &, const std::vector<long double> &) const 338 { 339 return 42; 340 } 341 #if defined(HEYOKA_HAVE_REAL128) eval_f128func_06342 mppp::real128 eval_f128(const std::unordered_map<std::string, mppp::real128> &, 343 const std::vector<mppp::real128> &) const 344 { 345 return mppp::real128(42); 346 } 347 #endif 348 }; 349 350 TEST_CASE("func eval_dbl") 351 { 352 auto f = func(func_06{}); 353 354 REQUIRE(f.eval_dbl({{}}, {}) == 42); 355 } 356 357 struct func_07 : func_base { func_07func_07358 func_07() : func_base("f", {}) {} func_07func_07359 explicit func_07(std::vector<expression> args) : func_base("f", std::move(args)) {} 360 eval_batch_dblfunc_07361 void eval_batch_dbl(std::vector<double> &, const std::unordered_map<std::string, std::vector<double>> &, 362 const std::vector<double> &) const 363 { 364 } 365 }; 366 367 TEST_CASE("func eval_batch_dbl") 368 { 369 auto f = func(func_07{}); 370 371 std::vector<double> tmp; 372 REQUIRE_NOTHROW(f.eval_batch_dbl(tmp, {{}}, {})); 373 } 374 375 struct func_08 : func_base { func_08func_08376 func_08() : func_base("f", {}) {} func_08func_08377 explicit func_08(std::vector<expression> args) : func_base("f", std::move(args)) {} 378 eval_num_dblfunc_08379 double eval_num_dbl(const std::vector<double> &) const 380 { 381 return 42; 382 } 383 }; 384 385 TEST_CASE("func eval_num_dbl") 386 { 387 auto f = func(func_08{{"x"_var}}); 388 389 REQUIRE(f.eval_num_dbl({1.}) == 42); 390 } 391 392 struct func_09 : func_base { func_09func_09393 func_09() : func_base("f", {}) {} func_09func_09394 explicit func_09(std::vector<expression> args) : func_base("f", std::move(args)) {} 395 deval_num_dblfunc_09396 double deval_num_dbl(const std::vector<double> &, std::vector<double>::size_type) const 397 { 398 return 43; 399 } 400 }; 401 402 TEST_CASE("func deval_num_dbl") 403 { 404 auto f = func(func_09{{"x"_var}}); 405 406 REQUIRE(f.deval_num_dbl({1.}, 0) == 43); 407 } 408 409 struct func_10 : func_base { func_10func_10410 func_10() : func_base("f", {}) {} func_10func_10411 explicit func_10(std::vector<expression> args) : func_base("f", std::move(args)) {} 412 taylor_decomposefunc_10413 taylor_dc_t::size_type taylor_decompose(taylor_dc_t &u_vars_defs) && 414 { 415 u_vars_defs.emplace_back("foo", std::vector<std::uint32_t>{}); 416 417 return u_vars_defs.size() - 1u; 418 } 419 }; 420 421 struct func_10a : func_base { func_10afunc_10a422 func_10a() : func_base("f", {}) {} func_10afunc_10a423 explicit func_10a(std::vector<expression> args) : func_base("f", std::move(args)) {} 424 taylor_decomposefunc_10a425 taylor_dc_t::size_type taylor_decompose(taylor_dc_t &u_vars_defs) && 426 { 427 u_vars_defs.emplace_back("foo", std::vector<std::uint32_t>{}); 428 429 return u_vars_defs.size(); 430 } 431 }; 432 433 struct func_10b : func_base { func_10bfunc_10b434 func_10b() : func_base("f", {}) {} func_10bfunc_10b435 explicit func_10b(std::vector<expression> args) : func_base("f", std::move(args)) {} 436 taylor_decomposefunc_10b437 taylor_dc_t::size_type taylor_decompose(taylor_dc_t &u_vars_defs) && 438 { 439 u_vars_defs.emplace_back("foo", std::vector<std::uint32_t>{}); 440 441 return 0; 442 } 443 }; 444 445 TEST_CASE("func taylor_decompose") 446 { 447 using Catch::Matchers::Message; 448 449 auto f = func(func_10{{"x"_var}}); 450 451 taylor_dc_t u_vars_defs{{"x"_var, {}}}; 452 std::unordered_map<const void *, taylor_dc_t::size_type> func_map; 453 REQUIRE(f.taylor_decompose(func_map, u_vars_defs) == 1u); 454 REQUIRE(u_vars_defs == taylor_dc_t{{"x"_var, {}}, {"foo"_var, {}}}); 455 456 func_map = {}; 457 458 f = func(func_10a{{"x"_var}}); 459 460 REQUIRE_THROWS_MATCHES( 461 f.taylor_decompose(func_map, u_vars_defs), std::invalid_argument, 462 Message("Invalid value returned by the Taylor decomposition function for the function 'f': " 463 "the return value is 3, which is not less than the current size of the decomposition " 464 "(3)")); 465 466 f = func(func_10b{{"x"_var}}); 467 468 REQUIRE_THROWS_MATCHES(f.taylor_decompose(func_map, u_vars_defs), std::invalid_argument, 469 Message("The return value for the Taylor decomposition of a function can never be zero")); 470 } 471 472 struct func_12 : func_base { func_12func_12473 func_12() : func_base("f", {}) {} func_12func_12474 explicit func_12(std::vector<expression> args) : func_base("f", std::move(args)) {} 475 taylor_diff_dblfunc_12476 llvm::Value *taylor_diff_dbl(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &, 477 llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t, 478 std::uint32_t, bool) const 479 { 480 return nullptr; 481 } taylor_diff_ldblfunc_12482 llvm::Value *taylor_diff_ldbl(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &, 483 llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t, 484 std::uint32_t, bool) const 485 { 486 return nullptr; 487 } 488 #if defined(HEYOKA_HAVE_REAL128) taylor_diff_f128func_12489 llvm::Value *taylor_diff_f128(llvm_state &, const std::vector<std::uint32_t> &, const std::vector<llvm::Value *> &, 490 llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t, 491 std::uint32_t, bool) const 492 { 493 return nullptr; 494 } 495 #endif 496 }; 497 498 TEST_CASE("func taylor diff") 499 { 500 using Catch::Matchers::Message; 501 502 auto f = func(func_12{}); 503 504 llvm_state s; 505 auto a = 0; 506 auto fake_ptr = reinterpret_cast<llvm::Value *>(&a); 507 REQUIRE_THROWS_MATCHES(f.taylor_diff_dbl(s, {}, {}, fake_ptr, fake_ptr, 1, 2, 3, 4, false), std::invalid_argument, 508 Message("Null return value detected in func::taylor_diff_dbl() for the function 'f'")); 509 REQUIRE_THROWS_MATCHES(f.taylor_diff_ldbl(s, {}, {}, fake_ptr, fake_ptr, 1, 2, 3, 4, false), std::invalid_argument, 510 Message("Null return value detected in func::taylor_diff_ldbl() for the function 'f'")); 511 #if defined(HEYOKA_HAVE_REAL128) 512 REQUIRE_THROWS_MATCHES(f.taylor_diff_f128(s, {}, {}, fake_ptr, fake_ptr, 1, 2, 3, 4, false), std::invalid_argument, 513 Message("Null return value detected in func::taylor_diff_f128() for the function 'f'")); 514 #endif 515 } 516 517 struct func_13 : func_base { func_13func_13518 func_13() : func_base("f", {}) {} func_13func_13519 explicit func_13(std::vector<expression> args) : func_base("f", std::move(args)) {} 520 taylor_c_diff_func_dblfunc_13521 llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const 522 { 523 return nullptr; 524 } taylor_c_diff_func_ldblfunc_13525 llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const 526 { 527 return nullptr; 528 } 529 #if defined(HEYOKA_HAVE_REAL128) taylor_c_diff_func_f128func_13530 llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const 531 { 532 return nullptr; 533 } 534 #endif 535 }; 536 537 TEST_CASE("func taylor c_diff") 538 { 539 using Catch::Matchers::Message; 540 541 auto f = func(func_13{}); 542 543 llvm_state s; 544 REQUIRE_THROWS_MATCHES( 545 f.taylor_c_diff_func_dbl(s, 3, 4, false), std::invalid_argument, 546 Message("Null return value detected in func::taylor_c_diff_func_dbl() for the function 'f'")); 547 REQUIRE_THROWS_MATCHES( 548 f.taylor_c_diff_func_ldbl(s, 2, 3, false), std::invalid_argument, 549 Message("Null return value detected in func::taylor_c_diff_func_ldbl() for the function 'f'")); 550 #if defined(HEYOKA_HAVE_REAL128) 551 REQUIRE_THROWS_MATCHES( 552 f.taylor_c_diff_func_f128(s, 2, 4, false), std::invalid_argument, 553 Message("Null return value detected in func::taylor_c_diff_func_f128() for the function 'f'")); 554 #endif 555 } 556 557 TEST_CASE("func swap") 558 { 559 using std::swap; 560 561 auto f1 = func(func_10{{"x"_var}}); 562 auto f2 = func(func_12{{"y"_var}}); 563 564 swap(f1, f2); 565 566 REQUIRE(f1.get_type_index() == typeid(func_12)); 567 REQUIRE(f2.get_type_index() == typeid(func_10)); 568 REQUIRE(f1.args() == std::vector{"y"_var}); 569 REQUIRE(f2.args() == std::vector{"x"_var}); 570 571 REQUIRE(std::is_nothrow_swappable_v<func>); 572 } 573 574 TEST_CASE("func ostream") 575 { 576 auto f1 = func(func_10{{"x"_var, "y"_var}}); 577 578 std::ostringstream oss; 579 oss << f1; 580 581 REQUIRE(oss.str() == "f(x, y)"); 582 583 oss.str(""); 584 585 f1 = func(func_10{{"y"_var}}); 586 587 oss << f1; 588 589 REQUIRE(oss.str() == "f(y)"); 590 } 591 592 TEST_CASE("func hash") 593 { 594 auto f1 = func(func_10{{"x"_var, "y"_var}}); 595 596 REQUIRE_NOTHROW(hash(f1)); 597 598 std::cout << "Hash value for f1: " << hash(f1) << '\n'; 599 } 600 601 struct func_14 : func_base { func_14func_14602 func_14(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args)) 603 { 604 } func_14func_14605 explicit func_14(std::vector<expression> args) : func_base("f", std::move(args)) {} 606 }; 607 608 TEST_CASE("func eq ineq") 609 { 610 auto f1 = func(func_10{{"x"_var, "y"_var}}); 611 612 REQUIRE(f1 == f1); 613 REQUIRE(!(f1 != f1)); 614 REQUIRE(hash(f1) == hash(f1)); 615 616 // Differing arguments. 617 auto f2 = func(func_10{{"y"_var, "x"_var}}); 618 619 REQUIRE(f1 != f2); 620 REQUIRE(!(f1 == f2)); 621 622 auto f3 = func(func_14{{"x"_var, "y"_var}}); 623 auto f4 = func(func_14{"g", {"x"_var, "y"_var}}); 624 625 // Differing names. 626 REQUIRE(f3 != f4); 627 REQUIRE(!(f3 == f4)); 628 629 // Differing underlying types. 630 f3 = func(func_10{{"x"_var, "y"_var}}); 631 f4 = func(func_14{{"x"_var, "y"_var}}); 632 633 REQUIRE(f3 != f4); 634 REQUIRE(!(f3 == f4)); 635 } 636 637 TEST_CASE("func get_variables") 638 { 639 auto f1 = func(func_10{{}}); 640 REQUIRE(get_variables(expression{f1}).empty()); 641 642 f1 = func(func_10{{0_dbl}}); 643 REQUIRE(get_variables(expression{f1}).empty()); 644 645 f1 = func(func_10{{0_dbl, "x"_var}}); 646 REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"x"}); 647 648 f1 = func(func_10{{0_dbl, "y"_var, "x"_var}}); 649 REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"x", "y"}); 650 f1 = func(func_10{{0_dbl, "y"_var, "x"_var, 1_dbl, "x"_var, "y"_var, "z"_var}}); 651 REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"x", "y", "z"}); 652 } 653 654 TEST_CASE("func rename_variables") 655 { 656 auto f1 = expression{func(func_10{{}})}; 657 auto f2 = f1; 658 rename_variables(f1, {{}}); 659 REQUIRE(f2 == f1); 660 661 f1 = expression{func(func_10{{0_dbl, "x"_var}})}; 662 f2 = f1; 663 rename_variables(f1, {{}}); 664 REQUIRE(f2 == f1); 665 666 rename_variables(f1, {{"x", "y"}}); 667 REQUIRE(f2 == f1); 668 REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"y"}); 669 rename_variables(f1, {{"x", "y"}}); 670 REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"y"}); 671 672 f1 = expression{func(func_10{{"x"_var, 0_dbl, "z"_var, "y"_var}})}; 673 rename_variables(f1, {{"x", "y"}}); 674 REQUIRE(f2 != f1); 675 REQUIRE(get_variables(expression{f1}) == std::vector<std::string>{"y", "z"}); 676 } 677 678 TEST_CASE("func diff free func") 679 { 680 using Catch::Matchers::Message; 681 682 auto f1 = func(func_05{{}}); 683 684 REQUIRE(diff(expression{f1}, "x") == 42_dbl); 685 686 f1 = func(func_00{}); 687 REQUIRE_THROWS_MATCHES(diff(expression{f1}, ""), not_implemented_error, 688 Message("Cannot compute the derivative of the function 'f' with respect to a variable, " 689 "because the function does not provide " 690 "neither a diff() nor a gradient() member function")); 691 REQUIRE_THROWS_MATCHES(diff(expression{f1}, par[0]), not_implemented_error, 692 Message("Cannot compute the derivative of the function 'f' with respect to a parameter, " 693 "because the function does not provide " 694 "neither a diff() nor a gradient() member function")); 695 } 696 697 struct func_15 : func_base { func_15func_15698 func_15(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args)) 699 { 700 } func_15func_15701 explicit func_15(std::vector<expression> args) : func_base("f", std::move(args)) {} 702 }; 703 704 TEST_CASE("func subs") 705 { 706 auto f1 = func(func_15{{"x"_var, "y"_var}}); 707 708 auto f2 = subs(expression{f1}, {{}}); 709 REQUIRE(f2 == expression{f1}); 710 711 f2 = subs(expression{f1}, {{"x", "z"_var}}); 712 REQUIRE(f2 == expression{func(func_15{{"z"_var, "y"_var}})}); 713 714 f2 = subs(expression{f1}, {{"x", "z"_var}, {"y", 42_dbl}}); 715 REQUIRE(f2 == expression{func(func_15{{"z"_var, 42_dbl}})}); 716 } 717 718 struct func_16 : func_base { func_16func_16719 func_16(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args)) 720 { 721 } func_16func_16722 explicit func_16(std::vector<expression> args) : func_base("f", std::move(args)) {} 723 to_streamfunc_16724 void to_stream(std::ostream &os) const 725 { 726 os << "Custom to stream"; 727 } 728 }; 729 730 TEST_CASE("func to_stream") 731 { 732 auto f1 = func(func_15{{"x"_var, "y"_var}}); 733 734 std::cout << "Default stream: " << f1 << '\n'; 735 736 auto f2 = func(func_16{{"x"_var, "y"_var}}); 737 738 std::ostringstream oss; 739 oss << f2; 740 REQUIRE(oss.str() == "Custom to stream"); 741 } 742 743 TEST_CASE("func extract") 744 { 745 auto f1 = func(func_15{{"x"_var, "y"_var}}); 746 747 REQUIRE(f1.extract<func_15>() != nullptr); 748 REQUIRE(static_cast<const func &>(f1).extract<func_15>() != nullptr); 749 750 REQUIRE(f1.extract<func_16>() == nullptr); 751 REQUIRE(static_cast<const func &>(f1).extract<func_16>() == nullptr); 752 753 #if !defined(_MSC_VER) || defined(__clang__) 754 // NOTE: vanilla MSVC does not like these extraction. 755 REQUIRE(f1.extract<const func_15>() == nullptr); 756 REQUIRE(static_cast<const func &>(f1).extract<const func_15>() == nullptr); 757 758 REQUIRE(f1.extract<int>() == nullptr); 759 REQUIRE(static_cast<const func &>(f1).extract<int>() == nullptr); 760 761 #endif 762 } 763 764 struct func_17 : func_base { func_17func_17765 func_17(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args)) 766 { 767 } func_17func_17768 explicit func_17(int n, std::vector<expression> args) : func_base("f", std::move(args)), value(n) {} 769 extra_equal_tofunc_17770 bool extra_equal_to(const func &f) const 771 { 772 return f.extract<func_17>()->value == value; 773 } 774 775 int value = 0; 776 }; 777 778 TEST_CASE("func extra_equal_to") 779 { 780 auto f1 = func(func_17{0, {"x"_var, "y"_var}}); 781 auto f2 = func(func_17{0, {"x"_var, "y"_var}}); 782 auto f3 = func(func_17{1, {"x"_var, "y"_var}}); 783 784 REQUIRE(f1 == f2); 785 REQUIRE(f1 != f3); 786 } 787 788 struct func_18 : func_base { func_18func_18789 func_18(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args)) 790 { 791 } func_18func_18792 explicit func_18(int n, std::vector<expression> args) : func_base("f", std::move(args)), value(n) {} 793 extra_hashfunc_18794 std::size_t extra_hash() const 795 { 796 return static_cast<std::size_t>(value); 797 } 798 799 int value = 0; 800 }; 801 802 TEST_CASE("func extra_hash") 803 { 804 auto f1 = func(func_18{0, {"x"_var, "y"_var}}); 805 auto f2 = func(func_18{0, {"x"_var, "y"_var}}); 806 auto f3 = func(func_18{-1, {"x"_var, "y"_var}}); 807 808 REQUIRE(hash(f1) == hash(f2)); 809 REQUIRE(hash(f1) != hash(f3)); 810 } 811 812 struct func_19 : func_base { func_19func_19813 func_19(std::string name = "pippo", std::vector<expression> args = {}) : func_base(std::move(name), std::move(args)) 814 { 815 } 816 817 private: 818 friend class boost::serialization::access; 819 template <typename Archive> serializefunc_19820 void serialize(Archive &ar, unsigned) 821 { 822 ar &boost::serialization::base_object<func_base>(*this); 823 } 824 }; 825 826 HEYOKA_S11N_FUNC_EXPORT(func_19) 827 828 TEST_CASE("func s11n") 829 { 830 std::stringstream ss; 831 832 func f{func_19{"pluto", {"x"_var}}}; 833 834 { 835 boost::archive::binary_oarchive oa(ss); 836 837 oa << f; 838 } 839 840 f = func{}; 841 842 { 843 boost::archive::binary_iarchive ia(ss); 844 845 ia >> f; 846 } 847 848 REQUIRE(f.get_name() == "pluto"); 849 REQUIRE(f.args().size() == 1u); 850 REQUIRE(f.args()[0] == "x"_var); 851 } 852 853 TEST_CASE("ref semantics") 854 { 855 auto [x, y, z] = make_vars("x", "y", "z"); 856 857 auto foo = (x + y) * z, bar = foo; 858 859 REQUIRE(std::get<func>(foo.value()).get_ptr() == std::get<func>(bar.value()).get_ptr()); 860 861 foo = x - y; 862 bar = foo; 863 864 REQUIRE(std::get<func>(foo.value()).get_ptr() == std::get<func>(bar.value()).get_ptr()); 865 } 866 867 TEST_CASE("copy") 868 { 869 auto [x, y, z] = make_vars("x", "y", "z"); 870 871 auto foo = ((x + y) * (z + x)) * ((z - x) * (y + x)); 872 873 auto foo_copy = expression{std::get<func>(foo.value()).copy()}; 874 875 // Copy creates a new obejct... 876 REQUIRE(std::get<func>(foo_copy.value()).get_ptr() != std::get<func>(foo.value()).get_ptr()); 877 878 // ... but it does not deep copy the arguments. 879 REQUIRE(std::get<func>(std::get<func>(foo_copy.value()).args()[0].value()).get_ptr() 880 == std::get<func>(std::get<func>(foo.value()).args()[0].value()).get_ptr()); 881 REQUIRE(std::get<func>(std::get<func>(foo_copy.value()).args()[1].value()).get_ptr() 882 == std::get<func>(std::get<func>(foo.value()).args()[1].value()).get_ptr()); 883 884 REQUIRE( 885 std::get<func>(std::get<func>(std::get<func>(foo_copy.value()).args()[0].value()).args()[0].value()).get_ptr() 886 == std::get<func>(std::get<func>(std::get<func>(foo.value()).args()[0].value()).args()[0].value()).get_ptr()); 887 REQUIRE( 888 std::get<func>(std::get<func>(std::get<func>(foo_copy.value()).args()[0].value()).args()[1].value()).get_ptr() 889 == std::get<func>(std::get<func>(std::get<func>(foo.value()).args()[0].value()).args()[1].value()).get_ptr()); 890 891 REQUIRE( 892 std::get<func>(std::get<func>(std::get<func>(foo_copy.value()).args()[1].value()).args()[0].value()).get_ptr() 893 == std::get<func>(std::get<func>(std::get<func>(foo.value()).args()[1].value()).args()[0].value()).get_ptr()); 894 REQUIRE( 895 std::get<func>(std::get<func>(std::get<func>(foo_copy.value()).args()[1].value()).args()[1].value()).get_ptr() 896 == std::get<func>(std::get<func>(std::get<func>(foo.value()).args()[1].value()).args()[1].value()).get_ptr()); 897 } 898