1 // --------------------------------------------------------------------- 2 // 3 // Copyright (C) 2020 by the deal.II authors 4 // 5 // This file is part of the deal.II library. 6 // 7 // The deal.II library is free software; you can use it, redistribute 8 // it, and/or modify it under the terms of the GNU Lesser General 9 // Public License as published by the Free Software Foundation; either 10 // version 2.1 of the License, or (at your option) any later version. 11 // The full text of the license can be found in the file LICENSE at 12 // the top level of the deal.II distribution. 13 // 14 // --------------------------------------------------------------------- 15 16 #ifndef dealii_differentiation_sd_symengine_number_visitor_internal_h 17 #define dealii_differentiation_sd_symengine_number_visitor_internal_h 18 19 #include <deal.II/base/config.h> 20 21 #ifdef DEAL_II_WITH_SYMENGINE 22 23 DEAL_II_DISABLE_EXTRA_DIAGNOSTICS 24 // Low level 25 # include <symengine/basic.h> 26 # include <symengine/dict.h> 27 # include <symengine/symengine_exception.h> 28 # include <symengine/symengine_rcp.h> 29 30 // Visitor 31 # include <symengine/visitor.h> 32 33 DEAL_II_ENABLE_EXTRA_DIAGNOSTICS 34 35 # include <deal.II/base/exceptions.h> 36 # include <deal.II/base/numbers.h> 37 38 # include <deal.II/differentiation/sd/symengine_number_types.h> 39 # include <deal.II/differentiation/sd/symengine_utilities.h> 40 41 # include <boost/serialization/split_member.hpp> 42 43 44 DEAL_II_NAMESPACE_OPEN 45 46 47 namespace Differentiation 48 { 49 namespace SD 50 { 51 namespace internal 52 { 53 /** 54 * A class that implements common subexpression elimination 55 * for dictionary visitor classes. 56 * 57 * It is intended that this class only be used in conjunction 58 * with the DictionarySubstitutionVisitor. 59 */ 60 template <typename ReturnType, typename ExpressionType> 61 class CSEDictionaryVisitor 62 { 63 using symbol_vector_pair = 64 std::vector<std::pair<SD::Expression, SD::Expression>>; 65 66 public: 67 /* 68 * Constructor. 69 */ 70 CSEDictionaryVisitor() = default; 71 72 /* 73 * Destructor. 74 */ 75 virtual ~CSEDictionaryVisitor() = default; 76 77 /** 78 * Initialize and perform common subexpression elimination. 79 * 80 * Here we build the reduced expressions for the @p dependent_functions 81 * as well as a list of intermediate, repeating symbolic expressions 82 * that are extracted @p dependent_functions. This operation leads to 83 * the elimination of repeated expressions, so they only have to be 84 * evaluated once. 85 * 86 * @param dependent_functions A vector of expressions that represent 87 * the dependent functions that CSE will be performed on. 88 */ 89 void 90 init(const types::symbol_vector &dependent_functions); 91 92 /** 93 * Evaluate the (reduced) dependent functions using the common 94 * subexpressions generated during the call to init(). 95 * 96 * The @p output_values are the numerical result of substituting 97 * each of the @p substitution_values for their corresponding entry 98 * of the @p independent_symbols. Specifically, we first calculate 99 * the numerical values of each reduced subexpression, and feed those 100 * into a reduced equivalent of the dependent variables. The result of 101 * this is then written into the @p output_values array. 102 * 103 * @param[out] output_values A pointer to the first element in an array 104 * of type @p ReturnType. After this call, the underlying array will 105 * hold the numerical result of substituting the @p substitution_values 106 * into the pre-registered dependent functions. 107 * @param[in] independent_symbols A vector of symbols that represent 108 * the independent variables that are arguments to the previously 109 * defined dependent functions. 110 * @param[in] substitution_values A pointer to the first element in an 111 * array of type @p ReturnType. Each entry in this array stores the 112 * numerical value that an independent variable is to take for the 113 * purpose of value substitution. 114 * 115 * @note It is expected that both the @p output_values and 116 * @p substitution_values arrays be correctly dimensioned, as there is 117 * no range checking performed on these data structures. 118 * The @p output_values array should have the same number of elements as 119 * the dependent functions first passed to this class in the call to 120 * init(). Similarly, the @p substitution_values array should have the 121 * same number of elements as the @p independent_symbols vector has. 122 * 123 * @note It is expected that there be a 1-1 correspondence between the 124 * entries in @p independent_symbols and @p substitution_values. This 125 * is not checked within this function, so it is up to the user to 126 * ensure that the relationships between the independent variables and 127 * their numerical value be correctly set up and maintained. 128 */ 129 void 130 call(ReturnType * output_values, 131 const types::symbol_vector &independent_symbols, 132 const ReturnType * substitution_values); 133 134 /** 135 * Write the data of this object to a stream for the purpose 136 * of serialization. 137 */ 138 template <class Archive> 139 void 140 save(Archive &archive, const unsigned int version) const; 141 142 /** 143 * Read the data for this object from a stream for the purpose 144 * of serialization. 145 */ 146 template <class Archive> 147 void 148 load(Archive &archive, const unsigned int version); 149 150 # ifdef DOXYGEN 151 /** 152 * Write and read the data of this object from a stream for the purpose 153 * of serialization. 154 */ 155 template <class Archive> 156 void 157 serialize(Archive &archive, const unsigned int version); 158 # else 159 // This macro defines the serialize() method that is compatible with 160 // the templated save() and load() method that have been implemented. 161 BOOST_SERIALIZATION_SPLIT_MEMBER() 162 # endif 163 164 /** 165 * Print all of the intermediate reduced expressions, as well as 166 * the final reduced expressions for dependent variables to the 167 * @p stream. 168 */ 169 template <typename StreamType> 170 void 171 print(StreamType &stream) const; 172 173 /** 174 * Return a flag stating whether we've performed CSE or not. 175 */ 176 bool 177 executed() const; 178 179 /** 180 * The number of intermediate expressions that must be evaluated as 181 * part of the collection of common subexpressions. 182 */ 183 unsigned int 184 n_intermediate_expressions() const; 185 186 /** 187 * The size of the final set of reduced expressions. 188 */ 189 unsigned int 190 n_reduced_expressions() const; 191 192 protected: 193 /** 194 * Initialize and perform common subexpression elimination. 195 * 196 * This function performs the same action as the other init() function, 197 * except that it works with native SymEngine data types; 198 * the single argument is a vector of 199 * `SymEngine::RCP<const SymEngine::Basic>`. 200 * 201 * @param dependent_functions A vector of expressions that represent 202 * the dependent functions that CSE will be performed on. 203 */ 204 void 205 init(const SymEngine::vec_basic &dependent_functions); 206 207 /** 208 * Evaluate the (reduced) dependent functions using the common 209 * subexpressions generated during the call to init(). 210 * 211 * This function performs the same action as the other init() function, 212 * except that it works with native SymEngine data types. 213 * The @p independent_symbols is represented by a vector of 214 * `SymEngine::RCP<const SymEngine::Basic>`. 215 * 216 * @param[out] output_values A pointer to the first element in an array 217 * of type @p ReturnType. After this call, the underlying array will 218 * hold the numerical result of substituting the @p substitution_values 219 * into the pre-registered dependent functions. 220 * @param[in] independent_symbols A vector of symbols that represent 221 * the independent variables that are arguments to the previously 222 * defined dependent functions. 223 * @param[in] substitution_values A pointer to the first element in an 224 * array of type @p ReturnType. Each entry in this array stores the 225 * numerical value that an independent variable is to take for the 226 * purpose of value substitution. 227 * 228 * @note The caveats described in the documentation of the other call() 229 * function apply here as well. 230 */ 231 void 232 call(ReturnType * output_values, 233 const SymEngine::vec_basic &independent_symbols, 234 const ReturnType * substitution_values); 235 236 private: 237 // Note: It would be more efficient to store this data in native 238 // SymEngine types, as it would prevent some copying of the data 239 // structures. However, this makes serialization more difficult, 240 // so we use our own serializable types instead, and lose a bit 241 // of efficiency. 242 243 /** 244 * Intermediate symbols and their definition. 245 */ 246 symbol_vector_pair intermediate_symbols_exprs; 247 248 /** 249 * Final reduced expressions. 250 */ 251 types::symbol_vector reduced_exprs; 252 }; 253 254 255 256 /** 257 * A class to perform dictionary-based substitution as if it were an 258 * optimizer of the "lambda" or "LLVM" variety. 259 * 260 * This class is only really useful to assist in the easy switching 261 * between different optimizers and, more importantly, for integrating 262 * CSE into a dictionary substitution scheme. It is therefore only 263 * intended to be created and used by a BatchOptimizer. 264 */ 265 template <typename ReturnType, typename ExpressionType> 266 class DictionarySubstitutionVisitor 267 : public SymEngine::BaseVisitor< 268 DictionarySubstitutionVisitor<ReturnType, ExpressionType>> 269 { 270 public: 271 /* 272 * Constructor. 273 */ 274 DictionarySubstitutionVisitor() = default; 275 276 /* 277 * Destructor. 278 */ 279 virtual ~DictionarySubstitutionVisitor() override = default; 280 281 /** 282 * Initialization, and registration of the independent and dependent 283 * variables. 284 * 285 * This variation of the initialization function registers a single 286 * dependent expression. If the @p use_cse is set to `true`, then common 287 * subexpression elimination is also performed at the time of 288 * initialization. 289 * 290 * @param independent_symbols A vector of symbols that represent 291 * the independent variables that are arguments to the 292 * @p dependent_function. 293 * @param dependent_function A single symbolic expression that 294 * represents a dependent variable. 295 * @param use_cse A flag to indicate whether or not to use common 296 * subexpression elimination. 297 * 298 * @note After this function is called, no further registration of 299 * dependent functions can be performed. If it is desired that multiple 300 * dependent expressions be registered, then the other variant of this 301 * function that takes in a vector of dependent expressions should be 302 * used. 303 */ 304 void 305 init(const types::symbol_vector &independent_symbols, 306 const Expression & dependent_function, 307 const bool use_cse = false); 308 309 /** 310 * Initialization, and registration of the independent and dependent 311 * variables. 312 * 313 * This function performs the same action as the other init() function, 314 * described above, except that it works with native SymEngine data 315 * types. The @p independent_symbols are represented by a vector of 316 * `SymEngine::RCP<const SymEngine::Basic>`, and the 317 * @p dependent_function is of type 318 * `SymEngine::RCP<const SymEngine::Basic>`. 319 * 320 * @param independent_symbols A vector of symbols that represent 321 * the independent variables that are arguments to the 322 * @p dependent_function. 323 * @param dependent_function A single symbolic expression that 324 * represents a dependent variable. 325 * @param use_cse A flag to indicate whether or not to use common 326 * subexpression elimination. 327 * 328 * @note The caveats described in the documentation of the other init() 329 * function apply here as well. 330 */ 331 // The following definition is required due to base class CRTP. 332 void 333 init(const SymEngine::vec_basic &independent_symbols, 334 const SymEngine::Basic & dependent_function, 335 const bool use_cse = false); 336 337 /** 338 * Initialization, and registration of the independent and dependent 339 * variables. 340 * 341 * This variation of the initialization function registers a vector of 342 * dependent expressions. If the @p use_cse is set to `true`, then 343 * common subexpression elimination is also performed at the time of 344 * initialization. 345 * 346 * @param independent_symbols A vector of symbols that represent 347 * the independent variables that are arguments to the 348 * @p dependent_functions. 349 * @param dependent_functions A vector of symbolic expressions that 350 * represent the dependent variables. 351 * @param use_cse A flag to indicate whether or not to use common 352 * subexpression elimination. 353 * 354 * @note After this function is called, no further registration of 355 * dependent functions can be performed. 356 */ 357 void 358 init(const types::symbol_vector &independent_symbols, 359 const types::symbol_vector &dependent_functions, 360 const bool use_cse = false); 361 362 363 /** 364 * Initialization, and registration of the independent and dependent 365 * variables. 366 * 367 * This function performs the same action as the other init() function, 368 * described above, except that it works with native SymEngine data 369 * types. Both the @p independent_symbols and @p dependent_functions 370 * are represented by a vector of 371 * `SymEngine::RCP<const SymEngine::Basic>`. 372 * 373 * @param independent_symbols A vector of symbols that represent 374 * the independent variables that are arguments to the 375 * @p dependent_functions. 376 * @param dependent_functions A vector of symbolic expressions that 377 * represent the dependent variables. 378 * @param use_cse A flag to indicate whether or not to use common 379 * subexpression elimination. 380 * 381 * @note The caveats described in the documentation of the other init() 382 * function apply here as well. 383 */ 384 // The following definition is required due to base class CRTP. 385 void 386 init(const SymEngine::vec_basic &independent_symbols, 387 const SymEngine::vec_basic &dependent_functions, 388 const bool use_cse = false); 389 390 /** 391 * Evaluate the dependent functions that were registered at 392 * initializtion time. 393 * 394 * The @p output_values are the numerical result of substituting 395 * each of the @p substitution_values for their corresponding entry 396 * of the pre-registered independent variables. 397 * 398 * @param[out] output_values A pointer to the first element in an array 399 * of type @p ReturnType. After this call, the underlying array will 400 * hold the numerical result of substituting the @p substitution_values 401 * into the pre-registered dependent functions. 402 * @param[in] substitution_values A pointer to the first element in an 403 * array of type @p ReturnType. Each entry in this array stores the 404 * numerical value that an independent variable is to take for the 405 * purpose of value substitution. 406 * 407 * @note It is expected that both the @p output_values and 408 * @p substitution_values arrays be correctly dimensioned, as there is 409 * no range checking performed on these data structures. 410 * The @p output_values array should have the same number of elements as 411 * the dependent functions first passed to this class in the call to 412 * init(). Similarly, the @p substitution_values array should have the 413 * same number of elements as the there were independent variables 414 * that were registered at the time of initialization. 415 * 416 * @note It is expected that there be a 1-1 correspondence between the 417 * entries in independent variables and @p substitution_values. This 418 * is not checked within this function, so it is up to the user to 419 * ensure that the relationships between the independent variables and 420 * their numerical value be correctly set up and maintained. 421 */ 422 void 423 call(ReturnType *output_values, const ReturnType *substitution_values); 424 425 /** 426 * Evaluate the dependent function that were registered at 427 * initializtion time. 428 * 429 * The purpose of this function is the same as the other call() 430 * functions, but 431 * 432 * @param substitution_values A vector that stores the numerical values 433 * that each independent variable that was registered with the class 434 * instance is to take for the purpose of value substitution. 435 * @return ReturnType The numerical value associated with the single 436 * dependent function that is registered with this class. 437 * 438 * @note It is expected that there be a 1-1 correspondence between the 439 * entries in independent variables and @p substitution_values. 440 */ 441 // The following definition is required due to base class CRTP. 442 ReturnType 443 call(const std::vector<ReturnType> &substitution_values); 444 445 /** 446 * Write the data of this object to a stream for the purpose 447 * of serialization. 448 */ 449 template <class Archive> 450 void 451 save(Archive &archive, const unsigned int version) const; 452 453 /** 454 * Read the data for this object from a stream for the purpose 455 * of serialization. 456 */ 457 template <class Archive> 458 void 459 load(Archive &archive, const unsigned int version); 460 461 # ifdef DOXYGEN 462 /** 463 * Write and read the data of this object from a stream for the purpose 464 * of serialization. 465 */ 466 template <class Archive> 467 void 468 serialize(Archive &archive, const unsigned int version); 469 # else 470 // This macro defines the serialize() method that is compatible with 471 // the templated save() and load() method that have been implemented. 472 BOOST_SERIALIZATION_SPLIT_MEMBER() 473 # endif 474 475 /** 476 * Print some information on state of the internal data 477 * structures stored in the class. 478 * 479 * @tparam Stream The type for the output stream. 480 * @param stream The output stream to print to. 481 * @param print_independent_symbols A flag to indicate if the independent 482 * variables should be outputted to the @p stream. 483 * @param print_dependent_functions A flag to indicate if the dependent 484 * expressions should be outputted to the @p stream. 485 * @param print_cse_reductions A flag to indicate whether or not all 486 * common subexpressions should be printed to the @p stream. 487 */ 488 template <typename StreamType> 489 void 490 print(StreamType &stream, 491 const bool print_independent_symbols = false, 492 const bool print_dependent_functions = false, 493 const bool print_cse_reductions = false) const; 494 495 # ifndef DOXYGEN 496 // The following definitions are required due to base class CRTP. 497 // Since these are not used, and therefore not important to 498 // understand, we'll define them in the most concise manner possible. 499 // We also won't bother to document their existence, since they cannot 500 // be used. 501 # define IMPLEMENT_DSV_BVISIT(Argument) \ 502 void bvisit(const Argument &) \ 503 { \ 504 AssertThrow(false, ExcNotImplemented()); \ 505 } 506 507 IMPLEMENT_DSV_BVISIT(SymEngine::Basic) 508 IMPLEMENT_DSV_BVISIT(SymEngine::Symbol) 509 IMPLEMENT_DSV_BVISIT(SymEngine::Constant) 510 IMPLEMENT_DSV_BVISIT(SymEngine::Integer) 511 IMPLEMENT_DSV_BVISIT(SymEngine::Rational) 512 IMPLEMENT_DSV_BVISIT(SymEngine::RealDouble) 513 IMPLEMENT_DSV_BVISIT(SymEngine::ComplexDouble) 514 IMPLEMENT_DSV_BVISIT(SymEngine::Add) 515 IMPLEMENT_DSV_BVISIT(SymEngine::Mul) 516 IMPLEMENT_DSV_BVISIT(SymEngine::Pow) 517 IMPLEMENT_DSV_BVISIT(SymEngine::Log) 518 IMPLEMENT_DSV_BVISIT(SymEngine::Sin) 519 IMPLEMENT_DSV_BVISIT(SymEngine::Cos) 520 IMPLEMENT_DSV_BVISIT(SymEngine::Tan) 521 IMPLEMENT_DSV_BVISIT(SymEngine::Csc) 522 IMPLEMENT_DSV_BVISIT(SymEngine::Sec) 523 IMPLEMENT_DSV_BVISIT(SymEngine::Cot) 524 IMPLEMENT_DSV_BVISIT(SymEngine::ASin) 525 IMPLEMENT_DSV_BVISIT(SymEngine::ACos) 526 IMPLEMENT_DSV_BVISIT(SymEngine::ATan) 527 IMPLEMENT_DSV_BVISIT(SymEngine::ATan2) 528 IMPLEMENT_DSV_BVISIT(SymEngine::ACsc) 529 IMPLEMENT_DSV_BVISIT(SymEngine::ASec) 530 IMPLEMENT_DSV_BVISIT(SymEngine::ACot) 531 IMPLEMENT_DSV_BVISIT(SymEngine::Sinh) 532 IMPLEMENT_DSV_BVISIT(SymEngine::Cosh) 533 IMPLEMENT_DSV_BVISIT(SymEngine::Tanh) 534 IMPLEMENT_DSV_BVISIT(SymEngine::Csch) 535 IMPLEMENT_DSV_BVISIT(SymEngine::Sech) 536 IMPLEMENT_DSV_BVISIT(SymEngine::Coth) 537 IMPLEMENT_DSV_BVISIT(SymEngine::ASinh) 538 IMPLEMENT_DSV_BVISIT(SymEngine::ACosh) 539 IMPLEMENT_DSV_BVISIT(SymEngine::ATanh) 540 IMPLEMENT_DSV_BVISIT(SymEngine::ACsch) 541 IMPLEMENT_DSV_BVISIT(SymEngine::ACoth) 542 IMPLEMENT_DSV_BVISIT(SymEngine::ASech) 543 IMPLEMENT_DSV_BVISIT(SymEngine::Abs) 544 IMPLEMENT_DSV_BVISIT(SymEngine::Gamma) 545 IMPLEMENT_DSV_BVISIT(SymEngine::LogGamma) 546 IMPLEMENT_DSV_BVISIT(SymEngine::Erf) 547 IMPLEMENT_DSV_BVISIT(SymEngine::Erfc) 548 IMPLEMENT_DSV_BVISIT(SymEngine::Max) 549 IMPLEMENT_DSV_BVISIT(SymEngine::Min) 550 551 # undef IMPLEMENT_DSV_BVISIT 552 # endif // DOXYGEN 553 554 private: 555 // Note: It would be more efficient to store this data in native 556 // SymEngine types, as it would prevent some copying of the data 557 // structures. However, this makes serialization more difficult, 558 // so we use our own serializable types instead, and lose a bit 559 // of efficiency. 560 561 /** 562 * A vector of symbols that represent the independent variables. 563 */ 564 SD::types::symbol_vector independent_symbols; 565 566 /** 567 * A vector of expressions that represent dependent functions. 568 */ 569 SD::types::symbol_vector dependent_functions; 570 571 /** 572 * A data structure that may be used to invoke common subexpression 573 * elimination on the dependent functions, with the aim to decrease 574 * the time taken to evaluate them. 575 */ 576 CSEDictionaryVisitor<ReturnType, ExpressionType> cse; 577 }; 578 579 580 581 /* ------------------ inline and template functions ------------------ */ 582 583 584 # ifndef DOXYGEN 585 586 /* -------------- CommonSubexpressionEliminationVisitor -------------- */ 587 588 589 template <typename ReturnType, typename ExpressionType> 590 void init(const SD::types::symbol_vector & dependent_functions)591 CSEDictionaryVisitor<ReturnType, ExpressionType>::init( 592 const SD::types::symbol_vector &dependent_functions) 593 { 594 init(Utilities::convert_expression_vector_to_basic_vector( 595 dependent_functions)); 596 } 597 598 599 600 template <typename ReturnType, typename ExpressionType> 601 void init(const SymEngine::vec_basic & dependent_functions)602 CSEDictionaryVisitor<ReturnType, ExpressionType>::init( 603 const SymEngine::vec_basic &dependent_functions) 604 { 605 // After the next call, the data stored in replacements is structured 606 // as follows: 607 // 608 // replacements[i] := [f, f(x)] 609 // replacements[i].first = intermediate function label "f" 610 // replacements[i].second = intermediate function definition "f(x)" 611 // 612 // It is to be evaluated top down (i.e. index 0 to 613 // replacements.size()), with the results going back into the 614 // substitution map for the next levels. So for each "i", "x" are the 615 // superset of the input values and the previously evaluated [f_0(x), 616 // f_1(x), ..., f_{i-1}(x)]. 617 // 618 // The final result is a set of reduced expressions 619 // that must be computed after the replacement 620 // values have been computed. 621 SymEngine::vec_pair se_replacements; 622 SymEngine::vec_basic se_reduced_exprs; 623 SymEngine::cse(se_replacements, se_reduced_exprs, dependent_functions); 624 625 intermediate_symbols_exprs = 626 Utilities::convert_basic_pair_vector_to_expression_pair_vector( 627 se_replacements); 628 reduced_exprs = Utilities::convert_basic_vector_to_expression_vector( 629 se_reduced_exprs); 630 } 631 632 633 634 template <typename ReturnType, typename ExpressionType> 635 void call(ReturnType * output_values,const SD::types::symbol_vector & independent_symbols,const ReturnType * substitution_values)636 CSEDictionaryVisitor<ReturnType, ExpressionType>::call( 637 ReturnType * output_values, 638 const SD::types::symbol_vector &independent_symbols, 639 const ReturnType * substitution_values) 640 { 641 call(output_values, 642 Utilities::convert_expression_vector_to_basic_vector( 643 independent_symbols), 644 substitution_values); 645 } 646 647 648 649 template <typename ReturnType, typename ExpressionType> 650 void call(ReturnType * output_values,const SymEngine::vec_basic & independent_symbols,const ReturnType * substitution_values)651 CSEDictionaryVisitor<ReturnType, ExpressionType>::call( 652 ReturnType * output_values, 653 const SymEngine::vec_basic &independent_symbols, 654 const ReturnType * substitution_values) 655 { 656 Assert(n_reduced_expressions() > 0, ExcInternalError()); 657 658 // First we add the input values into the substitution map... 659 SymEngine::map_basic_basic substitution_value_map; 660 for (unsigned i = 0; i < independent_symbols.size(); ++i) 661 substitution_value_map[independent_symbols[i]] = 662 static_cast<const SymEngine::RCP<const SymEngine::Basic> &>( 663 ExpressionType(substitution_values[i])); 664 665 // ... followed by any intermediate evaluations due to the application 666 // of CSE. These are fed directly back into the substitution map... 667 for (unsigned i = 0; i < intermediate_symbols_exprs.size(); ++i) 668 { 669 const SymEngine::RCP<const SymEngine::Basic> &cse_symbol = 670 intermediate_symbols_exprs[i].first; 671 const SymEngine::RCP<const SymEngine::Basic> &cse_expr = 672 intermediate_symbols_exprs[i].second; 673 Assert(substitution_value_map.find(cse_symbol) == 674 substitution_value_map.end(), 675 ExcMessage( 676 "Reduced symbol already appears in substitution map. " 677 "Is there a clash between the reduced symbol name and " 678 "the symbol used for an independent variable?")); 679 substitution_value_map[cse_symbol] = 680 static_cast<const SymEngine::RCP<const SymEngine::Basic> &>( 681 ExpressionType(ExpressionType(cse_expr) 682 .template substitute_and_evaluate<ReturnType>( 683 substitution_value_map))); 684 } 685 686 // ... followed by the final reduced expressions 687 for (unsigned i = 0; i < reduced_exprs.size(); ++i) 688 output_values[i] = ExpressionType(reduced_exprs[i]) 689 .template substitute_and_evaluate<ReturnType>( 690 substitution_value_map); 691 } 692 693 694 695 template <typename ReturnType, typename ExpressionType> 696 template <class Archive> 697 void save(Archive & ar,const unsigned int)698 CSEDictionaryVisitor<ReturnType, ExpressionType>::save( 699 Archive &ar, 700 const unsigned int /*version*/) const 701 { 702 // The reduced expressions depend on the intermediate expressions, 703 // so we serialize the latter before the former. 704 ar &intermediate_symbols_exprs; 705 ar &reduced_exprs; 706 } 707 708 709 710 template <typename ReturnType, typename ExpressionType> 711 template <class Archive> 712 void load(Archive & ar,const unsigned int)713 CSEDictionaryVisitor<ReturnType, ExpressionType>::load( 714 Archive &ar, 715 const unsigned int /*version*/) 716 { 717 Assert(intermediate_symbols_exprs.empty(), ExcInternalError()); 718 Assert(reduced_exprs.empty(), ExcInternalError()); 719 720 // The reduced expressions depend on the intermediate expressions, 721 // so we deserialize the latter before the former. 722 ar &intermediate_symbols_exprs; 723 ar &reduced_exprs; 724 } 725 726 727 728 template <typename ReturnType, typename ExpressionType> 729 template <typename StreamType> 730 void print(StreamType & stream)731 CSEDictionaryVisitor<ReturnType, ExpressionType>::print( 732 StreamType &stream) const 733 { 734 stream << "Common subexpression elimination: \n"; 735 stream << " Intermediate reduced expressions: \n"; 736 for (unsigned i = 0; i < intermediate_symbols_exprs.size(); ++i) 737 { 738 const SymEngine::RCP<const SymEngine::Basic> &cse_symbol = 739 intermediate_symbols_exprs[i].first; 740 const SymEngine::RCP<const SymEngine::Basic> &cse_expr = 741 intermediate_symbols_exprs[i].second; 742 stream << " " << i << ": " << cse_symbol << " = " << cse_expr 743 << "\n"; 744 } 745 746 stream << " Final reduced expressions for dependent variables: \n"; 747 for (unsigned i = 0; i < reduced_exprs.size(); ++i) 748 stream << " " << i << ": " << reduced_exprs[i] << "\n"; 749 750 stream << std::flush; 751 } 752 753 754 755 template <typename ReturnType, typename ExpressionType> 756 bool executed()757 CSEDictionaryVisitor<ReturnType, ExpressionType>::executed() const 758 { 759 // For dictionary substitution, the CSE algorithm moves 760 // ownership of the dependent function expression definition 761 // to the entries in reduced_exprs. So its size thus determines 762 // whether CSE has been executed or not. 763 return (n_reduced_expressions() > 0) || 764 (n_intermediate_expressions() > 0); 765 } 766 767 768 769 template <typename ReturnType, typename ExpressionType> 770 unsigned int 771 CSEDictionaryVisitor<ReturnType, n_intermediate_expressions()772 ExpressionType>::n_intermediate_expressions() const 773 { 774 return intermediate_symbols_exprs.size(); 775 } 776 777 778 779 template <typename ReturnType, typename ExpressionType> 780 unsigned int n_reduced_expressions()781 CSEDictionaryVisitor<ReturnType, ExpressionType>::n_reduced_expressions() 782 const 783 { 784 return reduced_exprs.size(); 785 } 786 787 788 789 /* ------------------ DictionarySubstitutionVisitor ------------------ */ 790 791 792 template <typename ReturnType, typename ExpressionType> 793 void init(const types::symbol_vector & inputs,const SD::Expression & output,const bool use_cse)794 DictionarySubstitutionVisitor<ReturnType, ExpressionType>::init( 795 const types::symbol_vector &inputs, 796 const SD::Expression & output, 797 const bool use_cse) 798 { 799 init(inputs, types::symbol_vector{output}, use_cse); 800 } 801 802 803 804 template <typename ReturnType, typename ExpressionType> 805 void init(const SymEngine::vec_basic & inputs,const SymEngine::Basic & output,const bool use_cse)806 DictionarySubstitutionVisitor<ReturnType, ExpressionType>::init( 807 const SymEngine::vec_basic &inputs, 808 const SymEngine::Basic & output, 809 const bool use_cse) 810 { 811 init(Utilities::convert_basic_vector_to_expression_vector(inputs), 812 SD::Expression(output.rcp_from_this()), 813 use_cse); 814 } 815 816 817 818 template <typename ReturnType, typename ExpressionType> 819 void init(const SymEngine::vec_basic & inputs,const SymEngine::vec_basic & outputs,const bool use_cse)820 DictionarySubstitutionVisitor<ReturnType, ExpressionType>::init( 821 const SymEngine::vec_basic &inputs, 822 const SymEngine::vec_basic &outputs, 823 const bool use_cse) 824 { 825 init(Utilities::convert_basic_vector_to_expression_vector(inputs), 826 Utilities::convert_basic_vector_to_expression_vector(outputs), 827 use_cse); 828 } 829 830 831 832 template <typename ReturnType, typename ExpressionType> 833 void init(const types::symbol_vector & inputs,const types::symbol_vector & outputs,const bool use_cse)834 DictionarySubstitutionVisitor<ReturnType, ExpressionType>::init( 835 const types::symbol_vector &inputs, 836 const types::symbol_vector &outputs, 837 const bool use_cse) 838 { 839 independent_symbols.clear(); 840 dependent_functions.clear(); 841 842 independent_symbols = inputs; 843 844 // Perform common subexpression elimination if requested 845 // Note: After this is done, the results produced by 846 // dependent_functions and cse.reduced_exprs should be 847 // the same. We could keep the former so that we can print 848 // out the original expressions if we wish to do so. 849 if (use_cse == false) 850 dependent_functions = outputs; 851 else 852 { 853 cse.init(outputs); 854 } 855 } 856 857 858 859 template <typename ReturnType, typename ExpressionType> 860 ReturnType call(const std::vector<ReturnType> & substitution_values)861 DictionarySubstitutionVisitor<ReturnType, ExpressionType>::call( 862 const std::vector<ReturnType> &substitution_values) 863 { 864 Assert( 865 dependent_functions.size() == 1, 866 ExcMessage( 867 "Cannot use this call function when more than one symbolic expression is to be evaluated.")); 868 Assert( 869 substitution_values.size() == independent_symbols.size(), 870 ExcMessage( 871 "Input substitution vector does not match size of symbol vector.")); 872 873 ReturnType out = dealii::internal::NumberType<ReturnType>::value(0.0); 874 call(&out, substitution_values.data()); 875 return out; 876 } 877 878 879 880 template <typename ReturnType, typename ExpressionType> 881 void call(ReturnType * output_values,const ReturnType * substitution_values)882 DictionarySubstitutionVisitor<ReturnType, ExpressionType>::call( 883 ReturnType * output_values, 884 const ReturnType *substitution_values) 885 { 886 // Check to see if CSE has been performed 887 if (cse.executed()) 888 { 889 cse.call(output_values, independent_symbols, substitution_values); 890 } 891 else 892 { 893 // Build a substitution map. 894 SymEngine::map_basic_basic substitution_value_map; 895 for (unsigned i = 0; i < independent_symbols.size(); ++i) 896 substitution_value_map[independent_symbols[i]] = 897 static_cast<const SymEngine::RCP<const SymEngine::Basic> &>( 898 ExpressionType(substitution_values[i])); 899 900 // Since we don't know how to definitively evaluate the 901 // input number type, we create a generic Expression 902 // with the given symbolic expression and ask it to perform 903 // substitution and evaluation for us. 904 Assert(dependent_functions.size() > 0, ExcInternalError()); 905 for (unsigned i = 0; i < dependent_functions.size(); ++i) 906 output_values[i] = 907 ExpressionType(dependent_functions[i]) 908 .template substitute_and_evaluate<ReturnType>( 909 substitution_value_map); 910 } 911 } 912 913 914 915 template <typename ReturnType, typename ExpressionType> 916 template <class Archive> 917 void save(Archive & ar,const unsigned int version)918 DictionarySubstitutionVisitor<ReturnType, ExpressionType>::save( 919 Archive & ar, 920 const unsigned int version) const 921 { 922 // Add some dynamic information to determine if CSE has been used, 923 // without relying on the CSE class when deserializing. 924 // const bool used_cse = cse.executed(); 925 // ar &used_cse; 926 927 // CSE and dependent variables both require the independent 928 // symbols, so we serialize them first. The dependent variables 929 // might depend on the outcome of CSE, so we have to serialize 930 // them last. 931 ar &independent_symbols; 932 cse.save(ar, version); 933 ar &dependent_functions; 934 } 935 936 937 938 template <typename ReturnType, typename ExpressionType> 939 template <class Archive> 940 void load(Archive & ar,const unsigned int version)941 DictionarySubstitutionVisitor<ReturnType, ExpressionType>::load( 942 Archive & ar, 943 const unsigned int version) 944 { 945 Assert(cse.executed() == false, ExcInternalError()); 946 Assert(cse.n_intermediate_expressions() == 0, ExcInternalError()); 947 Assert(cse.n_reduced_expressions() == 0, ExcInternalError()); 948 949 // CSE and dependent variables both require the independent 950 // symbols, so we deserialize them first. The dependent variables 951 // might depend on the outcome of CSE, so we have to deserialize 952 // them last. 953 ar &independent_symbols; 954 cse.load(ar, version); 955 ar &dependent_functions; 956 } 957 958 959 960 template <typename ReturnType, typename ExpressionType> 961 template <typename StreamType> 962 void print(StreamType & stream,const bool print_independent_symbols,const bool print_dependent_functions,const bool print_cse_reductions)963 DictionarySubstitutionVisitor<ReturnType, ExpressionType>::print( 964 StreamType &stream, 965 const bool print_independent_symbols, 966 const bool print_dependent_functions, 967 const bool print_cse_reductions) const 968 { 969 if (print_independent_symbols) 970 { 971 stream << "Independent variables: \n"; 972 for (unsigned i = 0; i < independent_symbols.size(); ++i) 973 stream << " " << i << ": " << independent_symbols[i] << "\n"; 974 975 stream << std::flush; 976 } 977 978 // Check to see if CSE has been performed 979 if (print_cse_reductions && cse.executed()) 980 { 981 cse.print(stream); 982 } 983 else 984 { 985 Assert(dependent_functions.size() > 0, ExcInternalError()); 986 987 if (print_dependent_functions) 988 { 989 stream << "Dependent variables: \n"; 990 for (unsigned i = 0; i < dependent_functions.size(); ++i) 991 stream << " " << i << dependent_functions[i] << "\n"; 992 993 stream << std::flush; 994 } 995 } 996 } 997 998 # endif // DOXYGEN 999 1000 } // namespace internal 1001 } // namespace SD 1002 } // namespace Differentiation 1003 1004 1005 DEAL_II_NAMESPACE_CLOSE 1006 1007 #endif // DEAL_II_WITH_SYMENGINE 1008 1009 #endif // dealii_differentiation_sd_symengine_number_visitor_internal_h 1010