1 /*
2  *  Copyright (C) 2004-2021 Edward F. Valeev
3  *
4  *  This file is part of Libint.
5  *
6  *  Libint is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  Libint is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU General Public License for more details.
15  *
16  *  You should have received a copy of the GNU General Public License
17  *  along with Libint.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 #ifndef _libint2_src_bin_libint_dg_h_
22 #define _libint2_src_bin_libint_dg_h_
23 
24 #include <iostream>
25 #include <string>
26 #include <list>
27 #include <map>
28 #include <vector>
29 #include <deque>
30 #include <algorithm>
31 #include <stdexcept>
32 #include <cassert>
33 
34 #include <global_macros.h>
35 #include <exception.h>
36 #include <smart_ptr.h>
37 #include <key.h>
38 #include <dgvertex.h>
39 
40 namespace libint2 {
41 
42 //  class DGVertex;
43   class DGArc;
44   template <class T> class DGArcRel;
45   template <class T> class AlgebraicOperator;
46   class Strategy;
47   class Tactic;
48   class CodeContext;
49   class MemoryManager;
50   class ImplicitDimensions;
51   class CodeSymbols;
52   class GraphRegistry;
53   class InternalGraphRegistry;
54 
55   /** DirectedGraph is an implementation of a directed graph
56       composed of vertices represented by DGVertex objects. Most important operations
57       will assume that this is a DAG, i.e. there are no directed cycles.
58 
59       \note The objects are allocated on free store and the graph is implemented as
60       an object of type 'vertices'.
61    */
62 
63   class DirectedGraph : public EnableSafePtrFromThis<DirectedGraph> {
64   public:
65     typedef DGVertex vertex;
66     typedef DGArc arc;
67     typedef SafePtr<DGVertex> ver_ptr;
68     typedef SafePtr<DGArc> arc_ptr;
69     typedef DGVertexKey key_type;
70     typedef std::multimap<key_type,ver_ptr> VPtrAssociativeContainer;
71     typedef std::list<ver_ptr> VPtrSequenceContainer;
72 
73     typedef VPtrSequenceContainer targets;
74 #if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
75     typedef VPtrAssociativeContainer vertices;
76 #else
77     typedef VPtrSequenceContainer vertices;
78 #endif
79     typedef targets::iterator target_iter;
80     typedef targets::const_iterator target_citer;
81     typedef vertices::iterator ver_iter;
82     typedef vertices::const_iterator ver_citer;
83     //not possible: typedef vertex::Address address;
84     typedef int address;
85     //not possible: typedef vertex::Size size;
86     typedef unsigned int size;
87     typedef std::vector<address> addresses;
88 
89   private:
90     /// converts what is stored in the container to a smart ptr to the vertex
vertex_ptr(const VPtrAssociativeContainer::value_type & v)91     static inline const ver_ptr& vertex_ptr(const VPtrAssociativeContainer::value_type& v) {
92       return v.second;
93     }
vertex_ptr(const VPtrSequenceContainer::value_type & v)94     static inline const ver_ptr& vertex_ptr(const VPtrSequenceContainer::value_type& v) {
95       return v;
96     }
97     /// converts what is stored in the container to a smart ptr to the vertex
vertex_ptr(VPtrAssociativeContainer::value_type & v)98     static inline ver_ptr& vertex_ptr(VPtrAssociativeContainer::value_type& v) {
99       return v.second;
100     }
vertex_ptr(VPtrSequenceContainer::value_type & v)101     static inline ver_ptr& vertex_ptr(VPtrSequenceContainer::value_type& v) {
102       return v;
103     }
104 
105   public:
106 
107     /** Creates an empty DAG. Actual initialization of the graph
108         must be done using append_target */
109     DirectedGraph();
110     ~DirectedGraph();
111 
112     /// Returns the number of vertices
num_vertices()113     unsigned int num_vertices() const { return stack_.size(); }
114 #if 0
115     /// Returns all vertices
116     const vertices& all_vertices() const { return stack_; }
117     /// Returns all targets
118     const targets& all_targets() const { return targets_; }
119 #endif
120     /// Find vertex v or it's equivalent on the graph. Return null pointer if not found.
121     /// Computational cost for a graph based on a nonassociative container may be high
find(const SafePtr<DGVertex> & v)122     const SafePtr<DGVertex>& find(const SafePtr<DGVertex>& v) const { return vertex_is_on(v); }
123 
124     /** appends v to the graph. If v's copy is already on the graph, return the pointer
125 	to the copy. Else return pointer to *v.
126      */
127     SafePtr<DGVertex> append_vertex(const SafePtr<DGVertex>& v);
128 
129     /** non-template append_target appends the vertex to the graph as a target
130      */
131     void append_target(const SafePtr<DGVertex>&);
132 
133     /** append_target appends I to the graph as a target vertex and applies
134         RR to it. append_target can be called multiple times on the same
135         graph if more than one target vertex is needed.
136 
137         I must derive from DGVertex. RR must derive from RecurrenceRelation.
138         RR has a constructor which takes const I& as the only argument.
139         RR must have a public member const I* child(unsigned int) .
140 
141         NOTE TO SELF : need to implement these restrictions using
142         standard Bjarne Stroustrup's approach.
143 
144      */
append_target(const SafePtr<I> & target)145     template <class I, class RR> void append_target(const SafePtr<I>& target) {
146       target->make_a_target();
147       recurse<I,RR>(target);
148     }
149 
150     /** apply_to_all applies RR to all vertices already on the graph.
151 
152         RR must derive from RecurrenceRelation. RR must define TargetType
153         as a typedef.
154         RR must have a public member const DGVertex* child(unsigned int) .
155 
156         NOTE TO SELF : need to implement these restrictions using
157         standard Bjarne Stroustrup's approach.
158 
159      */
apply_to_all()160     template <class RR> void apply_to_all() {
161       typedef typename RR::TargetType TT;
162       typedef vertices::const_iterator citer;
163       typedef vertices::iterator iter;
164       for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
165         ver_ptr& vptr = vertex_ptr(*v);
166         if ((vptr)->num_exit_arcs() != 0)
167           continue;
168         SafePtr<TT> tptr = dynamic_pointer_cast<TT,DGVertex>(v);
169         if (tptr == 0)
170           continue;
171 
172         SafePtr<RR> rr0(new RR(tptr));
173         const int num_children = rr0->num_children();
174 
175         for(int c=0; c<num_children; c++) {
176 
177           SafePtr<DGVertex> child = rr0->child(c);
178           SafePtr<DGArc> arc(new DGArcRel<RR>(tptr,child,rr0));
179           tptr->add_exit_arc(arc);
180 
181           recurse<RR>(child);
182 
183         }
184       }
185     }
186 
187     /** after all append_target's have been called, apply(strategy,tactic)
188       constructs a graph. strategy specifies how to apply recurrence relations.
189       The goal of strategies is to connect the target vertices to simpler, precomputable vertices.
190       There usually are many ways to reduce a vertex.
191       Tactic specifies which of these possibilities to choose.
192      */
193     void apply(const SafePtr<Strategy>& strategy,
194         const SafePtr<Tactic>& tactic);
195 
196     typedef void (DGVertex::* DGVertexMethodPtr)();
197     /** apply_at<method>(vertex) calls method() on vertex and all of its descendants
198      */
199     template <DGVertexMethodPtr method>
apply_at(const SafePtr<DGVertex> & vertex)200     void apply_at(const SafePtr<DGVertex>& vertex) const {
201       ((vertex.get())->*method)();
202       typedef DGVertex::ArcSetType::const_iterator aciter;
203       const aciter abegin = vertex->first_exit_arc();
204       const aciter aend = vertex->plast_exit_arc();
205       for(aciter a=abegin; a!=aend; ++a)
206         apply_at<method>((*a)->dest());
207     }
208 
209     /** calls Method(v) for each v, iterating in forward direction */
210     template <class Method>
foreach(Method & m)211     void foreach(Method& m) {
212       typedef vertices::const_iterator citer;
213       typedef vertices::iterator iter;
214       for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
215         ver_ptr& vptr = vertex_ptr(*v);
216         m(vptr);
217       }
218     }
219 
220     /** calls Method(v) for each v, iterating in forward direction */
221     template <class Method>
foreach(Method & m)222     void foreach(Method& m) const  {
223       typedef vertices::const_iterator citer;
224       typedef vertices::iterator iter;
225       for(citer v=stack_.begin(); v!=stack_.end(); ++v) {
226         const ver_ptr& vptr = vertex_ptr(*v);
227         m(vptr);
228       }
229     }
230     /** calls Method(v) for each v, iterating in reverse direction */
231     template <class Method>
rforeach(Method & m)232     void rforeach(Method& m) {
233       typedef vertices::const_reverse_iterator criter;
234       typedef vertices::reverse_iterator riter;
235       for(riter v=stack_.rbegin(); v!=stack_.rend(); ++v) {
236         ver_ptr& vptr = vertex_ptr(*v);
237         m(vptr);
238       }
239     }
240 
241     /** calls Method(v) for each v, iterating in reverse direction */
242     template <class Method>
rforeach(Method & m)243     void rforeach(Method& m) const {
244       typedef vertices::const_reverse_iterator criter;
245       typedef vertices::reverse_iterator riter;
246       for(criter v=stack_.rbegin(); v!=stack_.rend(); ++v) {
247         const ver_ptr& vptr = vertex_ptr(*v);
248         m(vptr);
249       }
250     }
251 
252     /** after Strategy has been applied, simple recurrence relations need to be
253         optimized away. optimize_rr_out() will replace all simple recurrence relations
254         with code representing them.
255      */
256     void optimize_rr_out(const SafePtr<CodeContext>& context);
257 
258     /** after all apply's have been called, traverse()
259         construct a heuristic order of traversal for the graph.
260      */
261     void traverse();
262 
263     /** update func_names_
264      */
265     void update_func_names();
266 
267     /// Prints out call sequence
268     void debug_print_traversal(std::ostream& os) const;
269 
270     /**
271     Prints out the graph in format understood by program "dot"
272     of package "graphviz". If symbols is true then label vertices
273     using their symbols rather than (descriptive) labels.
274      */
275     void print_to_dot(bool symbols, std::ostream& os = std::cout) const;
276 
277     /**
278        Generates code for the current computation using context.
279        dims specifies the implicit dimensions,
280        args specifies the code symbols for the arguments to the function,
281        label specifies the tag for the computation,
282        decl specifies the stream to receive declaration code,
283        code specifies the stream to receive the definition code
284      */
285     void generate_code(const SafePtr<CodeContext>& context, const SafePtr<MemoryManager>& memman,
286         const SafePtr<ImplicitDimensions>& dims, const SafePtr<CodeSymbols>& args,
287         const std::string& label,
288         std::ostream& decl, std::ostream& code);
289 
290     /** Resets the graph and all vertices. The stack of unresolved recurrence
291         relations is preserved.
292      */
293     void reset();
294 
295     /** num_children_on(rr) returns the number of children of rr which
296         are already on this graph.
297      */
298     template <class RR>
299     unsigned int
num_children_on(const SafePtr<RR> & rr)300     num_children_on(const SafePtr<RR>& rr) const {
301       unsigned int nchildren = rr->num_children();
302       unsigned int nchildren_on_stack = 0;
303       for(unsigned int c=0; c<nchildren; c++) {
304         if (!vertex_is_on(rr->rr_child(c)))
305           continue;
306         else
307           nchildren_on_stack++;
308       }
309 
310       return nchildren_on_stack;
311     }
312 
313     /// Returns the registry
registry()314     SafePtr<GraphRegistry>& registry() { return registry_; }
registry()315     const SafePtr<GraphRegistry>& registry() const { return registry_; }
316 
317     /// return the graph label
label()318     const std::string& label() const { return label_; }
319     /// sets label to \c new_label
set_label(const std::string & new_label)320     void set_label(const std::string& new_label) { label_ = new_label; }
321 
322     /// return true if there are vertices with 0 children but not pre-computed
323     bool missing_prerequisites() const;
324 
325   private:
326 
327     /// contains vertices
328     vertices stack_;
329     /// refers to targets, cannot be an associative container -- order of iteration over targets is important
330     targets targets_;
331     /// addresses of blocks which accumulate targets
332     addresses target_accums_;
333 
334     // graph label, used for annotating internal work, e.g. graphviz plots
335     std::string label_;
336 
337     typedef std::map<std::string,bool> FuncNameContainer;
338     /** Maintains the list of names of functions calls to which have been generated so far.
339         It is used to generate include statements.
340      */
341     FuncNameContainer func_names_;
342 
343 #if !USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
344     static const unsigned int default_size_ = 100;
345 #endif
346 
347     // maintains data about the graph which does not belong IN the graph
348     SafePtr<GraphRegistry> registry_;
349     // maintains private data about the graph which does not belong IN the graph
350     SafePtr<InternalGraphRegistry> iregistry_;
351 
352     /// Access the internal registry
iregistry()353     SafePtr<InternalGraphRegistry>& iregistry() { return iregistry_; }
iregistry()354     const SafePtr<InternalGraphRegistry>& iregistry() const { return iregistry_; }
355 
356     /** adds a vertex to the graph. If the vertex already found on the graph
357         then the vertex is not added and the function returns false */
358     SafePtr<DGVertex> add_vertex(const SafePtr<DGVertex>&);
359     /** same as add_vertex(), only assumes that there's no equivalent vertex on the graph (see vertex_is_on) */
360     void add_new_vertex(const SafePtr<DGVertex>&);
361     /// returns true if vertex if already on graph
362     const SafePtr<DGVertex>& vertex_is_on(const SafePtr<DGVertex>& vertex) const;
363     /// removes vertex from the graph. may throw CannotPerformOperation
364     void del_vertex(vertices::iterator&);
365     /** This function is used to implement (recursive) append_target().
366         vertex is appended to the graph and then RR is applied to is.
367      */
recurse(const SafePtr<I> & vertex)368     template <class I, class RR> void recurse(const SafePtr<I>& vertex)  {
369       SafePtr<DGVertex> dgvertex = add_vertex(vertex);
370       if (dgvertex != vertex)
371         return;
372 
373       SafePtr<RR> rr0(new RR(vertex));
374       const int num_children = rr0->num_children();
375 
376       for(int c=0; c<num_children; c++) {
377 
378         SafePtr<DGVertex> child = rr0->child(c);
379         SafePtr<DGArc> arc(new DGArcRel<RR>(vertex,child,rr0));
380         vertex->add_exit_arc(arc);
381 
382         SafePtr<I> child_cast = dynamic_pointer_cast<I,DGVertex>(child);
383         if (child_cast == 0)
384           throw std::runtime_error("DirectedGraph::recurse(const SafePtr<I>& vertex) -- dynamic cast failed, most probably this is a logic error!");
385         recurse<I,RR>(child_cast);
386 
387       }
388     }
389 
390     /** This function is used to implement (recursive) apply_to_all().
391         RR is applied to vertex and all its children.
392      */
recurse(const SafePtr<DGVertex> & vertex)393     template <class RR> void recurse(const SafePtr<DGVertex>& vertex)  {
394       SafePtr<DGVertex> dgvertex = add_vertex(vertex);
395       if (dgvertex != vertex)
396         return;
397 
398       typedef typename RR::TargetType TT;
399       SafePtr<TT> tptr = dynamic_pointer_cast<TT,DGVertex>(vertex);
400       if (tptr == 0)
401         return;
402 
403       SafePtr<RR> rr0(new RR(tptr));
404       const int num_children = rr0->num_children();
405 
406       for(int c=0; c<num_children; c++) {
407 
408         SafePtr<DGVertex> child = rr0->child(c);
409         SafePtr<DGArc> arc(new DGArcRel<RR>(vertex,child,rr0));
410         vertex->add_exit_arc(arc);
411 
412         recurse<RR>(child);
413       }
414     }
415 
416     /** This function is used to implement (recursive) apply().
417       strategy and tactic are applied to vertex and all its children.
418      */
419     void apply_to(const SafePtr<DGVertex>& vertex,
420         const SafePtr<Strategy>& strategy,
421         const SafePtr<Tactic>& tactic);
422     /// This function insert expr of type AlgebraicOperator<DGVertex> into the graph
423     SafePtr<DGVertex> insert_expr_at(const SafePtr<DGVertex>& where, const SafePtr< AlgebraicOperator<DGVertex> >& expr);
424     /// This function replaces RecurrenceRelations with concrete arithemtical expressions
425     void replace_rr_with_expr();
426     /// This function gets rid of trivial math such as multiplication/division by 1.0, etc.
427     void remove_trivial_arithmetics();
428     /** This function gets rid of nodes which are connected
429     to their equivalents (such as (ss|ss) shell quartet can only be connected to (ss|ss) integral)
430      */
431     void handle_trivial_nodes(const SafePtr<CodeContext>& context);
432     /// This functions removes vertices not connected to other vertices
433     void remove_disconnected_vertices();
434     /** Finds (binary) subtrees. The subtrees correspond to a single-line code (no intermediates
435         are used in other expressions)
436      */
437     void find_subtrees();
438     /** Finds (binary) subtrees starting (recursively) at v.
439      */
440     void find_subtrees_from(const SafePtr<DGVertex>& v);
441     /** If v1 and v2 are connected by DGArcDirect and all entry arcs to v1 are of the DGArcDirect type as well,
442         this function will reattach all arcs extering v1 to v2 and remove v1 from the graph altogether.
443 	May throw CannotPerformOperation.
444      */
445     bool remove_vertex_at(const SafePtr<DGVertex>& v1, const SafePtr<DGVertex>& v2);
446 
447     // Which vertex is the first to compute
448     SafePtr<DGVertex> first_to_compute_;
449     // prepare_to_traverse must be called before actual traversal
450     void prepare_to_traverse();
451     // traverse_from(arc) build recurively the traversal order
452     void traverse_from(const SafePtr<DGArc>&);
453     // schedule_computation(vertex) puts vertex first in the computation order
454     void schedule_computation(const SafePtr<DGVertex>&);
455 
456     // Compute addresses on stack assuming that quantities larger than min_size_to_alloc to be allocated on stack
457     void allocate_mem(const SafePtr<MemoryManager>& memman,
458         const SafePtr<ImplicitDimensions>& dims,
459         unsigned int min_size_to_alloc = 1);
460     // Assign symbols to the vertices
461     void assign_symbols(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims);
462     // If v is an AlgebraicOperator, assign (recursively) symbol to the operator. All other must have been already assigned
463     void assign_oper_symbol(const SafePtr<CodeContext>& context, SafePtr<DGVertex>& v);
464     // Print the code using symbols generated with assign_symbols()
465     void print_def(const SafePtr<CodeContext>& context, std::ostream& os,
466         const SafePtr<ImplicitDimensions>& dims,
467         const SafePtr<CodeSymbols>& args);
468 
469     /** Returns true if cannot enclose the code in a vector loop
470         Possible reason: the traversal path contains a RecurrenceRelation that generates a function call
471         (most do except IntegralSet_to_Integrals; \sa RecurrentRelation::is_simple() )
472         */
473     bool cannot_enclose_in_outer_vloop() const;
474 
475   };
476 
477   //
478   // Nonmember utilities
479   //
480 
481   /// converts what is stored in the container to a smart ptr to the vertex
vertex_ptr(const DirectedGraph::VPtrAssociativeContainer::value_type & v)482   inline const DirectedGraph::ver_ptr& vertex_ptr(const DirectedGraph::VPtrAssociativeContainer::value_type& v) {
483     return v.second;
484   }
vertex_ptr(const DirectedGraph::VPtrSequenceContainer::value_type & v)485   inline const DirectedGraph::ver_ptr& vertex_ptr(const DirectedGraph::VPtrSequenceContainer::value_type& v) {
486     return v;
487   }
488   /// converts what is stored in the container to a smart ptr to the vertex
vertex_ptr(DirectedGraph::VPtrAssociativeContainer::value_type & v)489   inline DirectedGraph::ver_ptr& vertex_ptr(DirectedGraph::VPtrAssociativeContainer::value_type& v) {
490     return v.second;
491   }
vertex_ptr(DirectedGraph::VPtrSequenceContainer::value_type & v)492   inline DirectedGraph::ver_ptr& vertex_ptr(DirectedGraph::VPtrSequenceContainer::value_type& v) {
493     return v;
494   }
495 
#if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
  /// Presumably returns the DGVertexKey under which v is stored in the associative vertex container.
  /// NOTE(review): declared inline but not defined in this header; an inline function must be
  /// defined in every translation unit that odr-uses it -- confirm where the definition lives.
  inline DirectedGraph::key_type key(const DGVertex& v);
#endif
499 
500   //
501   // Nonmember predicates
502   //
503 
  /// return true if there are non-unrolled targets among \c targets
  /// (the exact "unrolled" criterion is defined in the implementation)
  bool nonunrolled_targets(const DirectedGraph::targets& targets);

  /// extracts external symbols and RRs from the graph
  /// NOTE(review): where the extracted data is stored is not visible here (presumably dg's registry) -- confirm
  void extract_symbols(const SafePtr<DirectedGraph>& dg);
509 
510   // use these functors with DirectedGraph::foreach
511   struct PrerequisitesExtractor {
512     std::deque< SafePtr<DGVertex> > vertices;
513     void operator()(const SafePtr<DGVertex>& v);
514   };
515   struct VertexPrinter {
VertexPrinterVertexPrinter516     VertexPrinter(std::ostream& ostr) : os(ostr) {}
517     std::ostream& os;
518     void operator()(const SafePtr<DGVertex>& v);
519   };
520 
521 };
522 
523 
524 #endif
525