1 // ---------------------------------------------------------------------
2 //
3 // Copyright (C) 2020 by the deal.II authors
4 //
5 // This file is part of the deal.II library.
6 //
7 // The deal.II library is free software; you can use it, redistribute
8 // it, and/or modify it under the terms of the GNU Lesser General
9 // Public License as published by the Free Software Foundation; either
10 // version 2.1 of the License, or (at your option) any later version.
11 // The full text of the license can be found in the file LICENSE at
12 // the top level of the deal.II distribution.
13 //
14 // ---------------------------------------------------------------------
15 
16 #ifndef dealii_differentiation_sd_symengine_number_visitor_internal_h
17 #define dealii_differentiation_sd_symengine_number_visitor_internal_h
18 
19 #include <deal.II/base/config.h>
20 
21 #ifdef DEAL_II_WITH_SYMENGINE
22 
23 DEAL_II_DISABLE_EXTRA_DIAGNOSTICS
24 // Low level
25 #  include <symengine/basic.h>
26 #  include <symengine/dict.h>
27 #  include <symengine/symengine_exception.h>
28 #  include <symengine/symengine_rcp.h>
29 
30 // Visitor
31 #  include <symengine/visitor.h>
32 
33 DEAL_II_ENABLE_EXTRA_DIAGNOSTICS
34 
35 #  include <deal.II/base/exceptions.h>
36 #  include <deal.II/base/numbers.h>
37 
38 #  include <deal.II/differentiation/sd/symengine_number_types.h>
39 #  include <deal.II/differentiation/sd/symengine_utilities.h>
40 
41 #  include <boost/serialization/split_member.hpp>
42 
43 
44 DEAL_II_NAMESPACE_OPEN
45 
46 
47 namespace Differentiation
48 {
49   namespace SD
50   {
51     namespace internal
52     {
      /**
       * A class that implements common subexpression elimination
       * for dictionary visitor classes.
       *
       * It is intended that this class only be used in conjunction
       * with the DictionarySubstitutionVisitor.
       *
       * @tparam ReturnType The number type of the values that are substituted
       * for the independent symbols, and of the values that are written out
       * by call().
       * @tparam ExpressionType The symbolic expression type that call() uses
       * to perform the substitution and evaluation.
       */
      template <typename ReturnType, typename ExpressionType>
      class CSEDictionaryVisitor
      {
        // A list of (symbol, expression) pairs. It is used to store each
        // intermediate symbol together with its definition.
        using symbol_vector_pair =
          std::vector<std::pair<SD::Expression, SD::Expression>>;

      public:
        /**
         * Constructor.
         */
        CSEDictionaryVisitor() = default;

        /**
         * Destructor.
         */
        virtual ~CSEDictionaryVisitor() = default;

        /**
         * Initialize and perform common subexpression elimination.
         *
         * Here we build the reduced expressions for the @p dependent_functions
         * as well as a list of intermediate, repeating symbolic expressions
         * that are extracted from the @p dependent_functions. This operation
         * leads to the elimination of repeated expressions, so they only have
         * to be evaluated once.
         *
         * @param dependent_functions A vector of expressions that represent
         * the dependent functions that CSE will be performed on.
         */
        void
        init(const types::symbol_vector &dependent_functions);

        /**
         * Evaluate the (reduced) dependent functions using the common
         * subexpressions generated during the call to init().
         *
         * The @p output_values are the numerical result of substituting
         * each of the @p substitution_values for their corresponding entry
         * of the @p independent_symbols. Specifically, we first calculate
         * the numerical values of each reduced subexpression, and feed those
         * into a reduced equivalent of the dependent variables. The result of
         * this is then written into the @p output_values array.
         *
         * @param[out] output_values A pointer to the first element in an array
         * of type @p ReturnType. After this call, the underlying array will
         * hold the numerical result of substituting the @p substitution_values
         * into the pre-registered dependent functions.
         * @param[in]  independent_symbols A vector of symbols that represent
         * the independent variables that are arguments to the previously
         * defined dependent functions.
         * @param[in]  substitution_values A pointer to the first element in an
         * array of type @p ReturnType. Each entry in this array stores the
         * numerical value that an independent variable is to take for the
         * purpose of value substitution.
         *
         * @note It is expected that both the @p output_values and
         * @p substitution_values arrays be correctly dimensioned, as there is
         * no range checking performed on these data structures.
         * The @p output_values array should have the same number of elements as
         * the dependent functions first passed to this class in the call to
         * init(). Similarly, the @p substitution_values array should have the
         * same number of elements as the @p independent_symbols vector has.
         *
         * @note It is expected that there be a 1-1 correspondence between the
         * entries in @p independent_symbols and @p substitution_values. This
         * is not checked within this function, so it is up to the user to
         * ensure that the relationships between the independent variables and
         * their numerical value be correctly set up and maintained.
         */
        void
        call(ReturnType *                output_values,
             const types::symbol_vector &independent_symbols,
             const ReturnType *          substitution_values);

        /**
         * Write the data of this object to a stream for the purpose
         * of serialization.
         */
        template <class Archive>
        void
        save(Archive &archive, const unsigned int version) const;

        /**
         * Read the data for this object from a stream for the purpose
         * of serialization.
         */
        template <class Archive>
        void
        load(Archive &archive, const unsigned int version);

#  ifdef DOXYGEN
        /**
         * Write and read the data of this object from a stream for the purpose
         * of serialization.
         */
        template <class Archive>
        void
        serialize(Archive &archive, const unsigned int version);
#  else
        // This macro defines the serialize() method that is compatible with
        // the templated save() and load() methods that have been implemented.
        BOOST_SERIALIZATION_SPLIT_MEMBER()
#  endif

        /**
         * Print all of the intermediate reduced expressions, as well as
         * the final reduced expressions for dependent variables to the
         * @p stream.
         *
         * @tparam StreamType The type of the output stream.
         * @param stream The output stream to print to.
         */
        template <typename StreamType>
        void
        print(StreamType &stream) const;

        /**
         * Return a flag stating whether we've performed CSE or not.
         */
        bool
        executed() const;

        /**
         * The number of intermediate expressions that must be evaluated as
         * part of the collection of common subexpressions.
         */
        unsigned int
        n_intermediate_expressions() const;

        /**
         * The size of the final set of reduced expressions.
         */
        unsigned int
        n_reduced_expressions() const;

      protected:
        /**
         * Initialize and perform common subexpression elimination.
         *
         * This function performs the same action as the other init() function,
         * except that it works with native SymEngine data types;
         * the single argument is a vector of
         * `SymEngine::RCP<const SymEngine::Basic>`.
         *
         * @param dependent_functions A vector of expressions that represent
         * the dependent functions that CSE will be performed on.
         */
        void
        init(const SymEngine::vec_basic &dependent_functions);

        /**
         * Evaluate the (reduced) dependent functions using the common
         * subexpressions generated during the call to init().
         *
         * This function performs the same action as the other call() function,
         * except that it works with native SymEngine data types.
         * The @p independent_symbols is represented by a vector of
         * `SymEngine::RCP<const SymEngine::Basic>`.
         *
         * @param[out] output_values A pointer to the first element in an array
         * of type @p ReturnType. After this call, the underlying array will
         * hold the numerical result of substituting the @p substitution_values
         * into the pre-registered dependent functions.
         * @param[in]  independent_symbols A vector of symbols that represent
         * the independent variables that are arguments to the previously
         * defined dependent functions.
         * @param[in]  substitution_values A pointer to the first element in an
         * array of type @p ReturnType. Each entry in this array stores the
         * numerical value that an independent variable is to take for the
         * purpose of value substitution.
         *
         * @note The caveats described in the documentation of the other call()
         * function apply here as well.
         */
        void
        call(ReturnType *                output_values,
             const SymEngine::vec_basic &independent_symbols,
             const ReturnType *          substitution_values);

      private:
        // Note: It would be more efficient to store this data in native
        // SymEngine types, as it would prevent some copying of the data
        // structures. However, this makes serialization more difficult,
        // so we use our own serializable types instead, and lose a bit
        // of efficiency.

        /**
         * Intermediate symbols and their definition. For each entry, the
         * first element of the pair is the symbol and the second element is
         * the expression that defines it.
         */
        symbol_vector_pair intermediate_symbols_exprs;

        /**
         * Final reduced expressions.
         */
        types::symbol_vector reduced_exprs;
      };
253 
254 
255 
      /**
       * A class to perform dictionary-based substitution as if it were an
       * optimizer of the "lambda" or "LLVM" variety.
       *
       * This class is only really useful to assist in the easy switching
       * between different optimizers and, more importantly, for integrating
       * CSE into a dictionary substitution scheme. It is therefore only
       * intended to be created and used by a BatchOptimizer.
       *
       * @tparam ReturnType The number type of the values that are substituted
       * for the independent variables, and of the values that are returned by
       * the call() functions.
       * @tparam ExpressionType The symbolic expression type that is forwarded
       * to the CSEDictionaryVisitor member of this class.
       */
      template <typename ReturnType, typename ExpressionType>
      class DictionarySubstitutionVisitor
        : public SymEngine::BaseVisitor<
            DictionarySubstitutionVisitor<ReturnType, ExpressionType>>
      {
      public:
        /**
         * Constructor.
         */
        DictionarySubstitutionVisitor() = default;

        /**
         * Destructor.
         */
        virtual ~DictionarySubstitutionVisitor() override = default;

        /**
         * Initialization, and registration of the independent and dependent
         * variables.
         *
         * This variation of the initialization function registers a single
         * dependent expression. If the @p use_cse is set to `true`, then common
         * subexpression elimination is also performed at the time of
         * initialization.
         *
         * @param independent_symbols A vector of symbols that represent
         * the independent variables that are arguments to the
         * @p dependent_function.
         * @param dependent_function A single symbolic expression that
         * represents a dependent variable.
         * @param use_cse A flag to indicate whether or not to use common
         * subexpression elimination.
         *
         * @note After this function is called, no further registration of
         * dependent functions can be performed. If it is desired that multiple
         * dependent expressions be registered, then the other variant of this
         * function that takes in a vector of dependent expressions should be
         * used.
         */
        void
        init(const types::symbol_vector &independent_symbols,
             const Expression &          dependent_function,
             const bool                  use_cse = false);

        /**
         * Initialization, and registration of the independent and dependent
         * variables.
         *
         * This function performs the same action as the other init() function,
         * described above, except that it works with native SymEngine data
         * types. The @p independent_symbols are represented by a vector of
         * `SymEngine::RCP<const SymEngine::Basic>`, and the
         * @p dependent_function is of type
         * `SymEngine::RCP<const SymEngine::Basic>`.
         *
         * @param independent_symbols A vector of symbols that represent
         * the independent variables that are arguments to the
         * @p dependent_function.
         * @param dependent_function A single symbolic expression that
         * represents a dependent variable.
         * @param use_cse A flag to indicate whether or not to use common
         * subexpression elimination.
         *
         * @note The caveats described in the documentation of the other init()
         * function apply here as well.
         */
        // The following definition is required due to base class CRTP.
        void
        init(const SymEngine::vec_basic &independent_symbols,
             const SymEngine::Basic &    dependent_function,
             const bool                  use_cse = false);

        /**
         * Initialization, and registration of the independent and dependent
         * variables.
         *
         * This variation of the initialization function registers a vector of
         * dependent expressions. If the @p use_cse is set to `true`, then
         * common subexpression elimination is also performed at the time of
         * initialization.
         *
         * @param independent_symbols A vector of symbols that represent
         * the independent variables that are arguments to the
         * @p dependent_functions.
         * @param dependent_functions A vector of symbolic expressions that
         * represent the dependent variables.
         * @param use_cse A flag to indicate whether or not to use common
         * subexpression elimination.
         *
         * @note After this function is called, no further registration of
         * dependent functions can be performed.
         */
        void
        init(const types::symbol_vector &independent_symbols,
             const types::symbol_vector &dependent_functions,
             const bool                  use_cse = false);


        /**
         * Initialization, and registration of the independent and dependent
         * variables.
         *
         * This function performs the same action as the other init() function,
         * described above, except that it works with native SymEngine data
         * types. Both the @p independent_symbols and @p dependent_functions
         * are represented by a vector of
         * `SymEngine::RCP<const SymEngine::Basic>`.
         *
         * @param independent_symbols A vector of symbols that represent
         * the independent variables that are arguments to the
         * @p dependent_functions.
         * @param dependent_functions A vector of symbolic expressions that
         * represent the dependent variables.
         * @param use_cse A flag to indicate whether or not to use common
         * subexpression elimination.
         *
         * @note The caveats described in the documentation of the other init()
         * function apply here as well.
         */
        // The following definition is required due to base class CRTP.
        void
        init(const SymEngine::vec_basic &independent_symbols,
             const SymEngine::vec_basic &dependent_functions,
             const bool                  use_cse = false);

        /**
         * Evaluate the dependent functions that were registered at
         * initialization time.
         *
         * The @p output_values are the numerical result of substituting
         * each of the @p substitution_values for their corresponding entry
         * of the pre-registered independent variables.
         *
         * @param[out] output_values A pointer to the first element in an array
         * of type @p ReturnType. After this call, the underlying array will
         * hold the numerical result of substituting the @p substitution_values
         * into the pre-registered dependent functions.
         * @param[in]  substitution_values A pointer to the first element in an
         * array of type @p ReturnType. Each entry in this array stores the
         * numerical value that an independent variable is to take for the
         * purpose of value substitution.
         *
         * @note It is expected that both the @p output_values and
         * @p substitution_values arrays be correctly dimensioned, as there is
         * no range checking performed on these data structures.
         * The @p output_values array should have the same number of elements as
         * the dependent functions first passed to this class in the call to
         * init(). Similarly, the @p substitution_values array should have the
         * same number of elements as there were independent variables
         * registered at the time of initialization.
         *
         * @note It is expected that there be a 1-1 correspondence between the
         * entries in independent variables and @p substitution_values. This
         * is not checked within this function, so it is up to the user to
         * ensure that the relationships between the independent variables and
         * their numerical value be correctly set up and maintained.
         */
        void
        call(ReturnType *output_values, const ReturnType *substitution_values);

        /**
         * Evaluate the dependent function that was registered at
         * initialization time.
         *
         * The purpose of this function is the same as the other call()
         * function, but it takes the substitution values as a vector and
         * returns the value of the single registered dependent function
         * directly, rather than writing it into an output array.
         *
         * @param substitution_values A vector that stores the numerical values
         * that each independent variable that was registered with the class
         * instance is to take for the purpose of value substitution.
         * @return ReturnType The numerical value associated with the single
         * dependent function that is registered with this class.
         *
         * @note It is expected that there be a 1-1 correspondence between the
         * entries in independent variables and @p substitution_values.
         */
        // The following definition is required due to base class CRTP.
        ReturnType
        call(const std::vector<ReturnType> &substitution_values);

        /**
         * Write the data of this object to a stream for the purpose
         * of serialization.
         */
        template <class Archive>
        void
        save(Archive &archive, const unsigned int version) const;

        /**
         * Read the data for this object from a stream for the purpose
         * of serialization.
         */
        template <class Archive>
        void
        load(Archive &archive, const unsigned int version);

#  ifdef DOXYGEN
        /**
         * Write and read the data of this object from a stream for the purpose
         * of serialization.
         */
        template <class Archive>
        void
        serialize(Archive &archive, const unsigned int version);
#  else
        // This macro defines the serialize() method that is compatible with
        // the templated save() and load() methods that have been implemented.
        BOOST_SERIALIZATION_SPLIT_MEMBER()
#  endif

        /**
         * Print some information on the state of the internal data
         * structures stored in the class.
         *
         * @tparam StreamType The type for the output stream.
         * @param stream The output stream to print to.
         * @param print_independent_symbols A flag to indicate if the independent
         * variables should be outputted to the @p stream.
         * @param print_dependent_functions A flag to indicate if the dependent
         * expressions should be outputted to the @p stream.
         * @param print_cse_reductions A flag to indicate whether or not all
         * common subexpressions should be printed to the @p stream.
         */
        template <typename StreamType>
        void
        print(StreamType &stream,
              const bool  print_independent_symbols = false,
              const bool  print_dependent_functions = false,
              const bool  print_cse_reductions      = false) const;

#  ifndef DOXYGEN
        // The following definitions are required due to base class CRTP.
        // Since these are not used, and therefore not important to
        // understand, we'll define them in the most concise manner possible.
        // We also won't bother to document their existence, since they cannot
        // be used. Each generated bvisit() overload unconditionally throws
        // ExcNotImplemented.
#    define IMPLEMENT_DSV_BVISIT(Argument)       \
      void bvisit(const Argument &)              \
      {                                          \
        AssertThrow(false, ExcNotImplemented()); \
      }

        IMPLEMENT_DSV_BVISIT(SymEngine::Basic)
        IMPLEMENT_DSV_BVISIT(SymEngine::Symbol)
        IMPLEMENT_DSV_BVISIT(SymEngine::Constant)
        IMPLEMENT_DSV_BVISIT(SymEngine::Integer)
        IMPLEMENT_DSV_BVISIT(SymEngine::Rational)
        IMPLEMENT_DSV_BVISIT(SymEngine::RealDouble)
        IMPLEMENT_DSV_BVISIT(SymEngine::ComplexDouble)
        IMPLEMENT_DSV_BVISIT(SymEngine::Add)
        IMPLEMENT_DSV_BVISIT(SymEngine::Mul)
        IMPLEMENT_DSV_BVISIT(SymEngine::Pow)
        IMPLEMENT_DSV_BVISIT(SymEngine::Log)
        IMPLEMENT_DSV_BVISIT(SymEngine::Sin)
        IMPLEMENT_DSV_BVISIT(SymEngine::Cos)
        IMPLEMENT_DSV_BVISIT(SymEngine::Tan)
        IMPLEMENT_DSV_BVISIT(SymEngine::Csc)
        IMPLEMENT_DSV_BVISIT(SymEngine::Sec)
        IMPLEMENT_DSV_BVISIT(SymEngine::Cot)
        IMPLEMENT_DSV_BVISIT(SymEngine::ASin)
        IMPLEMENT_DSV_BVISIT(SymEngine::ACos)
        IMPLEMENT_DSV_BVISIT(SymEngine::ATan)
        IMPLEMENT_DSV_BVISIT(SymEngine::ATan2)
        IMPLEMENT_DSV_BVISIT(SymEngine::ACsc)
        IMPLEMENT_DSV_BVISIT(SymEngine::ASec)
        IMPLEMENT_DSV_BVISIT(SymEngine::ACot)
        IMPLEMENT_DSV_BVISIT(SymEngine::Sinh)
        IMPLEMENT_DSV_BVISIT(SymEngine::Cosh)
        IMPLEMENT_DSV_BVISIT(SymEngine::Tanh)
        IMPLEMENT_DSV_BVISIT(SymEngine::Csch)
        IMPLEMENT_DSV_BVISIT(SymEngine::Sech)
        IMPLEMENT_DSV_BVISIT(SymEngine::Coth)
        IMPLEMENT_DSV_BVISIT(SymEngine::ASinh)
        IMPLEMENT_DSV_BVISIT(SymEngine::ACosh)
        IMPLEMENT_DSV_BVISIT(SymEngine::ATanh)
        IMPLEMENT_DSV_BVISIT(SymEngine::ACsch)
        IMPLEMENT_DSV_BVISIT(SymEngine::ACoth)
        IMPLEMENT_DSV_BVISIT(SymEngine::ASech)
        IMPLEMENT_DSV_BVISIT(SymEngine::Abs)
        IMPLEMENT_DSV_BVISIT(SymEngine::Gamma)
        IMPLEMENT_DSV_BVISIT(SymEngine::LogGamma)
        IMPLEMENT_DSV_BVISIT(SymEngine::Erf)
        IMPLEMENT_DSV_BVISIT(SymEngine::Erfc)
        IMPLEMENT_DSV_BVISIT(SymEngine::Max)
        IMPLEMENT_DSV_BVISIT(SymEngine::Min)

#    undef IMPLEMENT_DSV_BVISIT
#  endif // DOXYGEN

      private:
        // Note: It would be more efficient to store this data in native
        // SymEngine types, as it would prevent some copying of the data
        // structures. However, this makes serialization more difficult,
        // so we use our own serializable types instead, and lose a bit
        // of efficiency.

        /**
         * A vector of symbols that represent the independent variables.
         */
        SD::types::symbol_vector independent_symbols;

        /**
         * A vector of expressions that represent dependent functions.
         */
        SD::types::symbol_vector dependent_functions;

        /**
         * A data structure that may be used to invoke common subexpression
         * elimination on the dependent functions, with the aim to decrease
         * the time taken to evaluate them.
         */
        CSEDictionaryVisitor<ReturnType, ExpressionType> cse;
      };
578 
579 
580 
581       /* ------------------ inline and template functions ------------------ */
582 
583 
584 #  ifndef DOXYGEN
585 
      /* ----------------------- CSEDictionaryVisitor ---------------------- */
587 
588 
589       template <typename ReturnType, typename ExpressionType>
590       void
init(const SD::types::symbol_vector & dependent_functions)591       CSEDictionaryVisitor<ReturnType, ExpressionType>::init(
592         const SD::types::symbol_vector &dependent_functions)
593       {
594         init(Utilities::convert_expression_vector_to_basic_vector(
595           dependent_functions));
596       }
597 
598 
599 
600       template <typename ReturnType, typename ExpressionType>
601       void
init(const SymEngine::vec_basic & dependent_functions)602       CSEDictionaryVisitor<ReturnType, ExpressionType>::init(
603         const SymEngine::vec_basic &dependent_functions)
604       {
605         // After the next call, the data stored in replacements is structured
606         // as follows:
607         //
608         // replacements[i] := [f, f(x)]
609         // replacements[i].first  = intermediate function label "f"
610         // replacements[i].second = intermediate function definition "f(x)"
611         //
612         // It is to be evaluated top down (i.e. index 0 to
613         // replacements.size()), with the results going back into the
614         // substitution map for the next levels. So for each "i", "x" are the
615         // superset of the input values and the previously evaluated [f_0(x),
616         // f_1(x), ..., f_{i-1}(x)].
617         //
618         // The final result is a set of reduced expressions
619         // that must be computed after the replacement
620         // values have been computed.
621         SymEngine::vec_pair  se_replacements;
622         SymEngine::vec_basic se_reduced_exprs;
623         SymEngine::cse(se_replacements, se_reduced_exprs, dependent_functions);
624 
625         intermediate_symbols_exprs =
626           Utilities::convert_basic_pair_vector_to_expression_pair_vector(
627             se_replacements);
628         reduced_exprs = Utilities::convert_basic_vector_to_expression_vector(
629           se_reduced_exprs);
630       }
631 
632 
633 
634       template <typename ReturnType, typename ExpressionType>
635       void
call(ReturnType * output_values,const SD::types::symbol_vector & independent_symbols,const ReturnType * substitution_values)636       CSEDictionaryVisitor<ReturnType, ExpressionType>::call(
637         ReturnType *                    output_values,
638         const SD::types::symbol_vector &independent_symbols,
639         const ReturnType *              substitution_values)
640       {
641         call(output_values,
642              Utilities::convert_expression_vector_to_basic_vector(
643                independent_symbols),
644              substitution_values);
645       }
646 
647 
648 
649       template <typename ReturnType, typename ExpressionType>
650       void
call(ReturnType * output_values,const SymEngine::vec_basic & independent_symbols,const ReturnType * substitution_values)651       CSEDictionaryVisitor<ReturnType, ExpressionType>::call(
652         ReturnType *                output_values,
653         const SymEngine::vec_basic &independent_symbols,
654         const ReturnType *          substitution_values)
655       {
656         Assert(n_reduced_expressions() > 0, ExcInternalError());
657 
658         // First we add the input values into the substitution map...
659         SymEngine::map_basic_basic substitution_value_map;
660         for (unsigned i = 0; i < independent_symbols.size(); ++i)
661           substitution_value_map[independent_symbols[i]] =
662             static_cast<const SymEngine::RCP<const SymEngine::Basic> &>(
663               ExpressionType(substitution_values[i]));
664 
665         // ... followed by any intermediate evaluations due to the application
666         // of CSE. These are fed directly back into the substitution map...
667         for (unsigned i = 0; i < intermediate_symbols_exprs.size(); ++i)
668           {
669             const SymEngine::RCP<const SymEngine::Basic> &cse_symbol =
670               intermediate_symbols_exprs[i].first;
671             const SymEngine::RCP<const SymEngine::Basic> &cse_expr =
672               intermediate_symbols_exprs[i].second;
673             Assert(substitution_value_map.find(cse_symbol) ==
674                      substitution_value_map.end(),
675                    ExcMessage(
676                      "Reduced symbol already appears in substitution map. "
677                      "Is there a clash between the reduced symbol name and "
678                      "the symbol used for an independent variable?"));
679             substitution_value_map[cse_symbol] =
680               static_cast<const SymEngine::RCP<const SymEngine::Basic> &>(
681                 ExpressionType(ExpressionType(cse_expr)
682                                  .template substitute_and_evaluate<ReturnType>(
683                                    substitution_value_map)));
684           }
685 
686         // ... followed by the final reduced expressions
687         for (unsigned i = 0; i < reduced_exprs.size(); ++i)
688           output_values[i] = ExpressionType(reduced_exprs[i])
689                                .template substitute_and_evaluate<ReturnType>(
690                                  substitution_value_map);
691       }
692 
693 
694 
695       template <typename ReturnType, typename ExpressionType>
696       template <class Archive>
697       void
save(Archive & ar,const unsigned int)698       CSEDictionaryVisitor<ReturnType, ExpressionType>::save(
699         Archive &ar,
700         const unsigned int /*version*/) const
701       {
702         // The reduced expressions depend on the intermediate expressions,
703         // so we serialize the latter before the former.
704         ar &intermediate_symbols_exprs;
705         ar &reduced_exprs;
706       }
707 
708 
709 
710       template <typename ReturnType, typename ExpressionType>
711       template <class Archive>
712       void
load(Archive & ar,const unsigned int)713       CSEDictionaryVisitor<ReturnType, ExpressionType>::load(
714         Archive &ar,
715         const unsigned int /*version*/)
716       {
717         Assert(intermediate_symbols_exprs.empty(), ExcInternalError());
718         Assert(reduced_exprs.empty(), ExcInternalError());
719 
720         // The reduced expressions depend on the intermediate expressions,
721         // so we deserialize the latter before the former.
722         ar &intermediate_symbols_exprs;
723         ar &reduced_exprs;
724       }
725 
726 
727 
728       template <typename ReturnType, typename ExpressionType>
729       template <typename StreamType>
730       void
print(StreamType & stream)731       CSEDictionaryVisitor<ReturnType, ExpressionType>::print(
732         StreamType &stream) const
733       {
734         stream << "Common subexpression elimination: \n";
735         stream << "  Intermediate reduced expressions: \n";
736         for (unsigned i = 0; i < intermediate_symbols_exprs.size(); ++i)
737           {
738             const SymEngine::RCP<const SymEngine::Basic> &cse_symbol =
739               intermediate_symbols_exprs[i].first;
740             const SymEngine::RCP<const SymEngine::Basic> &cse_expr =
741               intermediate_symbols_exprs[i].second;
742             stream << "  " << i << ": " << cse_symbol << " = " << cse_expr
743                    << "\n";
744           }
745 
746         stream << "  Final reduced expressions for dependent variables: \n";
747         for (unsigned i = 0; i < reduced_exprs.size(); ++i)
748           stream << "  " << i << ": " << reduced_exprs[i] << "\n";
749 
750         stream << std::flush;
751       }
752 
753 
754 
755       template <typename ReturnType, typename ExpressionType>
756       bool
executed()757       CSEDictionaryVisitor<ReturnType, ExpressionType>::executed() const
758       {
759         // For dictionary substitution, the CSE algorithm moves
760         // ownership of the dependent function expression definition
761         // to the entries in reduced_exprs. So its size thus determines
762         // whether CSE has been executed or not.
763         return (n_reduced_expressions() > 0) ||
764                (n_intermediate_expressions() > 0);
765       }
766 
767 
768 
769       template <typename ReturnType, typename ExpressionType>
770       unsigned int
771       CSEDictionaryVisitor<ReturnType,
n_intermediate_expressions()772                            ExpressionType>::n_intermediate_expressions() const
773       {
774         return intermediate_symbols_exprs.size();
775       }
776 
777 
778 
779       template <typename ReturnType, typename ExpressionType>
780       unsigned int
n_reduced_expressions()781       CSEDictionaryVisitor<ReturnType, ExpressionType>::n_reduced_expressions()
782         const
783       {
784         return reduced_exprs.size();
785       }
786 
787 
788 
789       /* ------------------ DictionarySubstitutionVisitor ------------------ */
790 
791 
792       template <typename ReturnType, typename ExpressionType>
793       void
init(const types::symbol_vector & inputs,const SD::Expression & output,const bool use_cse)794       DictionarySubstitutionVisitor<ReturnType, ExpressionType>::init(
795         const types::symbol_vector &inputs,
796         const SD::Expression &      output,
797         const bool                  use_cse)
798       {
799         init(inputs, types::symbol_vector{output}, use_cse);
800       }
801 
802 
803 
804       template <typename ReturnType, typename ExpressionType>
805       void
init(const SymEngine::vec_basic & inputs,const SymEngine::Basic & output,const bool use_cse)806       DictionarySubstitutionVisitor<ReturnType, ExpressionType>::init(
807         const SymEngine::vec_basic &inputs,
808         const SymEngine::Basic &    output,
809         const bool                  use_cse)
810       {
811         init(Utilities::convert_basic_vector_to_expression_vector(inputs),
812              SD::Expression(output.rcp_from_this()),
813              use_cse);
814       }
815 
816 
817 
818       template <typename ReturnType, typename ExpressionType>
819       void
init(const SymEngine::vec_basic & inputs,const SymEngine::vec_basic & outputs,const bool use_cse)820       DictionarySubstitutionVisitor<ReturnType, ExpressionType>::init(
821         const SymEngine::vec_basic &inputs,
822         const SymEngine::vec_basic &outputs,
823         const bool                  use_cse)
824       {
825         init(Utilities::convert_basic_vector_to_expression_vector(inputs),
826              Utilities::convert_basic_vector_to_expression_vector(outputs),
827              use_cse);
828       }
829 
830 
831 
832       template <typename ReturnType, typename ExpressionType>
833       void
init(const types::symbol_vector & inputs,const types::symbol_vector & outputs,const bool use_cse)834       DictionarySubstitutionVisitor<ReturnType, ExpressionType>::init(
835         const types::symbol_vector &inputs,
836         const types::symbol_vector &outputs,
837         const bool                  use_cse)
838       {
839         independent_symbols.clear();
840         dependent_functions.clear();
841 
842         independent_symbols = inputs;
843 
844         // Perform common subexpression elimination if requested
845         // Note: After this is done, the results produced by
846         // dependent_functions and cse.reduced_exprs should be
847         // the same. We could keep the former so that we can print
848         // out the original expressions if we wish to do so.
849         if (use_cse == false)
850           dependent_functions = outputs;
851         else
852           {
853             cse.init(outputs);
854           }
855       }
856 
857 
858 
859       template <typename ReturnType, typename ExpressionType>
860       ReturnType
call(const std::vector<ReturnType> & substitution_values)861       DictionarySubstitutionVisitor<ReturnType, ExpressionType>::call(
862         const std::vector<ReturnType> &substitution_values)
863       {
864         Assert(
865           dependent_functions.size() == 1,
866           ExcMessage(
867             "Cannot use this call function when more than one symbolic expression is to be evaluated."));
868         Assert(
869           substitution_values.size() == independent_symbols.size(),
870           ExcMessage(
871             "Input substitution vector does not match size of symbol vector."));
872 
873         ReturnType out = dealii::internal::NumberType<ReturnType>::value(0.0);
874         call(&out, substitution_values.data());
875         return out;
876       }
877 
878 
879 
880       template <typename ReturnType, typename ExpressionType>
881       void
call(ReturnType * output_values,const ReturnType * substitution_values)882       DictionarySubstitutionVisitor<ReturnType, ExpressionType>::call(
883         ReturnType *      output_values,
884         const ReturnType *substitution_values)
885       {
886         // Check to see if CSE has been performed
887         if (cse.executed())
888           {
889             cse.call(output_values, independent_symbols, substitution_values);
890           }
891         else
892           {
893             // Build a substitution map.
894             SymEngine::map_basic_basic substitution_value_map;
895             for (unsigned i = 0; i < independent_symbols.size(); ++i)
896               substitution_value_map[independent_symbols[i]] =
897                 static_cast<const SymEngine::RCP<const SymEngine::Basic> &>(
898                   ExpressionType(substitution_values[i]));
899 
900             // Since we don't know how to definitively evaluate the
901             // input number type, we create a generic Expression
902             // with the given symbolic expression and ask it to perform
903             // substitution and evaluation for us.
904             Assert(dependent_functions.size() > 0, ExcInternalError());
905             for (unsigned i = 0; i < dependent_functions.size(); ++i)
906               output_values[i] =
907                 ExpressionType(dependent_functions[i])
908                   .template substitute_and_evaluate<ReturnType>(
909                     substitution_value_map);
910           }
911       }
912 
913 
914 
915       template <typename ReturnType, typename ExpressionType>
916       template <class Archive>
917       void
save(Archive & ar,const unsigned int version)918       DictionarySubstitutionVisitor<ReturnType, ExpressionType>::save(
919         Archive &          ar,
920         const unsigned int version) const
921       {
922         // Add some dynamic information to determine if CSE has been used,
923         // without relying on the CSE class when deserializing.
924         // const bool used_cse = cse.executed();
925         // ar &used_cse;
926 
927         // CSE and dependent variables both require the independent
928         // symbols, so we serialize them first. The dependent variables
929         // might depend on the outcome of CSE, so we have to serialize
930         // them last.
931         ar &independent_symbols;
932         cse.save(ar, version);
933         ar &dependent_functions;
934       }
935 
936 
937 
938       template <typename ReturnType, typename ExpressionType>
939       template <class Archive>
940       void
load(Archive & ar,const unsigned int version)941       DictionarySubstitutionVisitor<ReturnType, ExpressionType>::load(
942         Archive &          ar,
943         const unsigned int version)
944       {
945         Assert(cse.executed() == false, ExcInternalError());
946         Assert(cse.n_intermediate_expressions() == 0, ExcInternalError());
947         Assert(cse.n_reduced_expressions() == 0, ExcInternalError());
948 
949         // CSE and dependent variables both require the independent
950         // symbols, so we deserialize them first. The dependent variables
951         // might depend on the outcome of CSE, so we have to deserialize
952         // them last.
953         ar &independent_symbols;
954         cse.load(ar, version);
955         ar &dependent_functions;
956       }
957 
958 
959 
960       template <typename ReturnType, typename ExpressionType>
961       template <typename StreamType>
962       void
print(StreamType & stream,const bool print_independent_symbols,const bool print_dependent_functions,const bool print_cse_reductions)963       DictionarySubstitutionVisitor<ReturnType, ExpressionType>::print(
964         StreamType &stream,
965         const bool  print_independent_symbols,
966         const bool  print_dependent_functions,
967         const bool  print_cse_reductions) const
968       {
969         if (print_independent_symbols)
970           {
971             stream << "Independent variables: \n";
972             for (unsigned i = 0; i < independent_symbols.size(); ++i)
973               stream << "  " << i << ": " << independent_symbols[i] << "\n";
974 
975             stream << std::flush;
976           }
977 
978         // Check to see if CSE has been performed
979         if (print_cse_reductions && cse.executed())
980           {
981             cse.print(stream);
982           }
983         else
984           {
985             Assert(dependent_functions.size() > 0, ExcInternalError());
986 
987             if (print_dependent_functions)
988               {
989                 stream << "Dependent variables: \n";
990                 for (unsigned i = 0; i < dependent_functions.size(); ++i)
991                   stream << "  " << i << dependent_functions[i] << "\n";
992 
993                 stream << std::flush;
994               }
995           }
996       }
997 
998 #  endif // DOXYGEN
999 
1000     } // namespace internal
1001   }   // namespace SD
1002 } // namespace Differentiation
1003 
1004 
1005 DEAL_II_NAMESPACE_CLOSE
1006 
1007 #endif // DEAL_II_WITH_SYMENGINE
1008 
1009 #endif // dealii_differentiation_sd_symengine_number_visitor_internal_h
1010