1 /*
2  *  Copyright (C) 2004-2021 Edward F. Valeev
3  *
4  *  This file is part of Libint.
5  *
6  *  Libint is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  Libint is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU General Public License for more details.
15  *
16  *  You should have received a copy of the GNU General Public License
17  *  along with Libint.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
#include <cstdio>
#include <fstream>
#include <functional>
#include <string>
#include <utility>
#include <dg.h>
#include <rr.h>
#include <strategy.h>
#include <prefactors.h>
#include <codeblock.h>
#include <default_params.h>
#include <graph_registry.h>
#include <global_macros.h>
#include <extract.h>
#include <algebra.h>
#include <task.h>
#include <context.h>
#include <intset_to_ints.h>
#include <uncontract.h>
#include <dims.h>
40 
41 using namespace std;
42 using namespace libint2;
43 
44 #define ONLY_CLONE_IF_DIFF 1
45 
46 //// utils first
47 
48 namespace {
push(DirectedGraph::VPtrAssociativeContainer & vertices,const DirectedGraph::ver_ptr & v)49   void push(DirectedGraph::VPtrAssociativeContainer& vertices, const DirectedGraph::ver_ptr& v) {
50     DirectedGraph::key_type vkey = libint2::key(*v);
51     vertices.insert(std::make_pair(vkey,v));
52   }
push(DirectedGraph::VPtrSequenceContainer & vertices,const DirectedGraph::ver_ptr & v)53   void push(DirectedGraph::VPtrSequenceContainer& vertices, const DirectedGraph::ver_ptr& v) {
54     vertices.push_back(v);
55   }
56 }
57 
58 
DirectedGraph()59 DirectedGraph::DirectedGraph() :
60   stack_(), targets_(), target_accums_(), label_("graph"), func_names_(),
61   registry_(SafePtr<GraphRegistry>(new GraphRegistry)),
62   iregistry_(SafePtr<InternalGraphRegistry>(new InternalGraphRegistry)),
63   first_to_compute_()
64 {
65   stack_.clear();
66   targets_.clear();
67 }
68 
~DirectedGraph()69 DirectedGraph::~DirectedGraph()
70 {
71   reset();
72 }
73 
74 void
append_target(const SafePtr<DGVertex> & target)75 DirectedGraph::append_target(const SafePtr<DGVertex>& target)
76 {
77   target->make_a_target();
78   append_vertex(target);
79   push(targets_,target);
80 }
81 
82 SafePtr<DGVertex>
append_vertex(const SafePtr<DGVertex> & vertex)83 DirectedGraph::append_vertex(const SafePtr<DGVertex>& vertex)
84 {
85   // If this vertex is owned by this graph, return it immediately
86   if (vertex->dg() == this) {
87 #if DEBUG
88     std::cout << "append_vertex: vertex " << vertex->label() << " is already on" << endl;
89 #endif
90     return vertex;
91   }
92   auto vcopy_on_graph = add_vertex(vertex);
93   // If this is a new vertex -- tell the vertex who its owner is now
94   if (vcopy_on_graph == vertex)
95     vertex->dg(this);
96   return vcopy_on_graph;
97 }
98 
99 SafePtr<DGVertex>
add_vertex(const SafePtr<DGVertex> & vertex)100 DirectedGraph::add_vertex(const SafePtr<DGVertex>& vertex)
101 {
102   // if vertex is precomputed -- can replicate it without consequences
103   // else check if it's on already
104   if (!vertex->precomputed()) {
105     SafePtr<DGVertex> vcopy_on_graph = vertex_is_on(vertex);
106     if (vcopy_on_graph)
107       return vcopy_on_graph;
108   }
109   add_new_vertex(vertex);
110   return vertex;
111 }
112 
113 void
add_new_vertex(const SafePtr<DGVertex> & vertex)114 DirectedGraph::add_new_vertex(const SafePtr<DGVertex>& vertex)
115 {
116 #if 0
117   // Resize if using std::vector
118 #if !USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
119   if (num_vertices() == stack_.capacity()) {
120     stack_.resize( stack_.capacity() + default_size_ );
121 #if DEBUG
122     cout << "Increased size of DirectedGraph's stack to "
123          << stack_.size() << endl;
124 #endif
125   }
126 #endif
127 #endif
128 
129   char label[80];  sprintf(label,"vertex%d",num_vertices());
130   vertex->set_graph_label(label);
131   vertex->dg(this);
132 
133   push(stack_,vertex);
134 #if DEBUG
135   cout << "add_new_vertex: added vertex " << vertex->description() << endl;
136 #endif
137 
138   return;
139 }
140 
141 const SafePtr<DGVertex>&
vertex_is_on(const SafePtr<DGVertex> & vertex) const142 DirectedGraph::vertex_is_on(const SafePtr<DGVertex>& vertex) const
143 {
144   if (vertex->dg() == this)
145     return vertex;
146 
147   static SafePtr<DGVertex> null_ptr;
148 #if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
149   typedef vertices::const_iterator citer;
150   typedef vertices::value_type value_type;
151   key_type vkey = key(*vertex);
152   // find the first elemnt with this key and iterate until vertex is found or key changes
153   citer vpos = stack_.find(vkey);
154   const citer end = stack_.end();
155   if (vpos != end) {
156     bool can_find = true;
157     while(can_find) {
158       if (can_find && (vpos->second)->equiv(vertex)) {
159 #if DEBUG
160 	std::cout << "vertex_is_on: " << (vpos->second)->label() << std::endl;
161 #endif
162 	return vpos->second;
163       }
164       ++vpos;
165       can_find = (vpos->first == vkey) && (vpos != end);
166     }
167   }
168 #else
169   typedef vertices::const_reverse_iterator criter;
170   typedef vertices::reverse_iterator riter;
171   const criter rend = stack_.rend();
172   for(criter v=stack_.rbegin(); v!=rend; ++v) {
173     const ver_ptr& vptr = vertex_ptr(*v);
174     if(vertex->equiv(vptr)) {
175 #if DEBUG
176       std::cout << "vertex_is_on: " << (vptr)->label() << std::endl;
177 #endif
178       return vptr;
179     }
180   }
181 #endif
182 #if DEBUG
183   std::cout << "vertex_is_on: NOT " << (vertex)->label() << std::endl;
184 #endif
185   return null_ptr;
186 }
187 
188 namespace {
189   struct __reset_dgvertex {
operator ()__anonad020a590211::__reset_dgvertex190     void operator()(SafePtr<DGVertex>& v) {
191       v->reset();
192 #if DEBUG
193       std::cout << "DirectedGraph::reset: will unregister " << v->label() << std::endl;
194 #endif
195       // remove this vertex from its SingletonManager
196       v->unregister();
197     }
198   };
199   struct __reset_safeptr {
operator ()__anonad020a590211::__reset_safeptr200     void operator()(SafePtr<DGVertex>& v) {
201       v.reset();
202     }
203   };
204 }
205 
206 void
del_vertex(vertices::iterator & v)207 DirectedGraph::del_vertex(vertices::iterator& v)
208 {
209   static __reset_dgvertex rv;
210   if (v == stack_.end())
211     throw CannotPerformOperation("DirectedGraph::del_vertex() cannot delete vertex");
212   ver_ptr& vptr = vertex_ptr(*v);
213   // Cannot delete targets. Should I be able to? Probably not
214   if (vptr->is_a_target())
215     throw CannotPerformOperation("DirectedGraph::del_vertex() cannot delete targets");
216   if (vptr->num_exit_arcs() == 0 && vptr->num_entry_arcs() == 0) {
217 #if DEBUG
218     std::cout << "del_vertex: trying to remove " << (vertex_ptr(*v))->label() << std::endl;
219 #endif
220     SafePtr<DGVertex> vptr = vertex_ptr(*v); // keep an instance of the pointer to avoid accidental automatic destruction of the DGVertex object
221     stack_.erase(v);
222     rv(vptr);
223 #if DEBUG
224     std::cout << "del_vertex: successful vertex removal " << std::endl;
225 #endif
226   }
227   else
228     throw CannotPerformOperation("DirectedGraph::del_vertex() cannot delete vertex");
229 }
230 
231 namespace{
232   struct __prepare_to_traverse {
operator ()__anonad020a590311::__prepare_to_traverse233     void operator()(SafePtr<DGVertex>& v) {
234       v->prepare_to_traverse();
235     }
236   };
237 
sort_children_by_nparents(DGVertex::ArcSetType::const_iterator begin,DGVertex::ArcSetType::const_iterator end)238   std::vector<DGVertex::ArcSetType::value_type> sort_children_by_nparents(DGVertex::ArcSetType::const_iterator begin,
239                                  DGVertex::ArcSetType::const_iterator end) {
240     // std::sort works only for containers that support random access
241 //    std::vector<DGVertex::ArcSetType::value_type> sorted_children;
242 //    for(DGVertex::ArcSetType::const_iterator i=begin; i!=end; ++i)
243 //      sorted_children.push_back(*i);
244     std::vector<DGVertex::ArcSetType::value_type> sorted_children(begin, end);
245     std::sort(sorted_children.begin(), sorted_children.end(),
246               [](DGVertex::ArcSetType::value_type a,
247                  DGVertex::ArcSetType::value_type b) {
248           return a->dest()->num_entry_arcs() < b->dest()->num_entry_arcs();
249         }
250       );
251     return sorted_children;
252   }
253 }
254 
255 void
prepare_to_traverse()256 DirectedGraph::prepare_to_traverse()
257 {
258   __prepare_to_traverse __ptt;
259   foreach(__ptt);
260 }
261 
262 /**
263  * Recursively traverse depth-first, once a node has been tagged by all of its parents schedule its computation
264  */
265 void
traverse()266 DirectedGraph::traverse()
267 {
268   // Initialization
269   prepare_to_traverse();
270 
271   // Start at the targets which don't have parents
272   typedef vertices::const_iterator citer;
273   typedef vertices::iterator iter;
274   for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
275     const ver_ptr& vptr = vertex_ptr(*v);
276     if ((vptr)->is_a_target() && (vptr)->num_entry_arcs() == 0) {
277       // First, since this target doesn't have parents we can schedule its computation
278       schedule_computation(vptr);
279 
280       //
281       // traverse the rest of graph starting from each child
282       // traversal will start with children with the fewest parents
283       // an explanation for this heuristic: I want to compute the children with 1 parent after
284       // other children so that I can potentially fold their computation with
285       // addition/subtraction to produce FMA and other composite instructions
286       //
287       {
288         // std::sort works only fro containers that support random access
289         std::vector<DGVertex::ArcSetType::value_type> sorted_children = sort_children_by_nparents(vptr->first_exit_arc(),
290                                                                                                   vptr->plast_exit_arc());
291         for(auto a=sorted_children.begin(); a!=sorted_children.end(); ++a) {
292           traverse_from(*a);
293         }
294       }
295 
296     }
297   }
298 }
299 
300 void
traverse_from(const SafePtr<DGArc> & arc)301 DirectedGraph::traverse_from(const SafePtr<DGArc>& arc)
302 {
303   SafePtr<DGVertex> orig = arc->orig();
304   SafePtr<DGVertex> dest = arc->dest();
305 #if DEBUG_TRAVERSAL
306   std::cout << "traverse_from: orig = " << orig << " dest = " << dest << endl;
307 #endif
308   // no need to compute if precomputed OR has no children
309   if (dest->precomputed() || dest->num_exit_arcs() == 0) {
310 #if DEBUG_TRAVERSAL
311     std::cout << "traverse from: dest is precomputed" << std::endl;
312 #endif
313     return;
314   }
315   // if has been hit by all parents ...
316   const unsigned int num_tags = dest->tag();
317 #if DEBUG_TRAVERSAL
318   std::cout << "traverse from: tagged dest, ntags = " << num_tags << std::endl;
319 #endif
320   const unsigned int num_parents = dest->num_entry_arcs();
321   if (num_tags == num_parents) {
322 
323     // ... check if it is on a subtree ...
324     if (SafePtr<DRTree> stree = dest->subtree()) {
325       // ... if yes, schedule only if it is a root of a subtree
326       if (stree->root() == dest)
327         schedule_computation(dest);
328     }
329     // else schedule
330     else
331       schedule_computation(dest);
332 
333     {
334       // std::sort works only fro containers that support random access
335       std::vector<DGVertex::ArcSetType::value_type> sorted_children = sort_children_by_nparents(dest->first_exit_arc(),
336                                                                                                 dest->plast_exit_arc());
337       for(auto a=sorted_children.begin(); a!=sorted_children.end(); ++a) {
338         traverse_from(*a);
339       }
340     }
341 
342   }
343 }
344 
345 void
schedule_computation(const SafePtr<DGVertex> & vertex)346 DirectedGraph::schedule_computation(const SafePtr<DGVertex>& vertex)
347 {
348   vertex->set_postcalc(first_to_compute_);
349   first_to_compute_ = vertex;
350 #if DEBUG || DEBUG_TRAVERSAL
351   std::cout << "schedule_computation: " << vertex << endl;
352   vertex->print(std::cout);
353 #endif
354 }
355 
356 
357 void
debug_print_traversal(std::ostream & os) const358 DirectedGraph::debug_print_traversal(std::ostream& os) const
359 {
360   SafePtr<DGVertex> current_vertex = first_to_compute_;
361 
362   os << "Debug print of traversal order" << endl;
363 
364   do {
365     current_vertex->print(os);
366     current_vertex = current_vertex->postcalc();
367   } while (current_vertex != 0);
368 }
369 
370 namespace {
371 
372   struct __print_vertices_to_dot {
373     bool symbols;
374     std::ostream& os;
__print_vertices_to_dot__anonad020a590511::__print_vertices_to_dot375     __print_vertices_to_dot(bool s, std::ostream& o) : symbols(s), os(o) {}
operator ()__anonad020a590511::__print_vertices_to_dot376     void operator()(const SafePtr<DGVertex>& v) {
377       os << "  " << v->graph_label()
378 	 << " [ label = \"";
379       if (symbols && v->symbol_set())
380 	os << v->symbol();
381       else
382 	os << v->label();
383       os << "\"]" << endl;
384     }
385   };
386 
387   struct __print_arcs_to_dot {
388     std::ostream& os;
__print_arcs_to_dot__anonad020a590511::__print_arcs_to_dot389     __print_arcs_to_dot(std::ostream& o) : os(o) {}
operator ()__anonad020a590511::__print_arcs_to_dot390     void operator()(const SafePtr<DGVertex>& v) {
391       typedef DGVertex::ArcSetType::const_iterator aciter;
392       const aciter abegin = v->first_exit_arc();
393       const aciter aend = v->plast_exit_arc();
394       for(aciter a=abegin; a!=aend; ++a) {
395 	SafePtr<DGVertex> dest = (*a)->dest();
396 	os << "  " << v->graph_label() << " -> "
397 	   << dest->graph_label() << endl;
398       }
399     }
400   };
401 
402 }
403 
404 void
print_to_dot(bool symbols,std::ostream & os) const405 DirectedGraph::print_to_dot(bool symbols, std::ostream& os) const
406 {
407   os << "digraph G {" << endl
408      << "  size = \"8,8\"" << endl;
409 
410   __print_vertices_to_dot pvtd(symbols,os);
411   foreach(pvtd);
412 
413   __print_arcs_to_dot patd(os);
414   foreach(patd);
415 
416   // Print traversal order using dotted lines
417   SafePtr<DGVertex> current_vertex = first_to_compute_;
418   if (current_vertex != 0) {
419     do {
420       SafePtr<DGVertex> next = current_vertex->postcalc();
421       if (current_vertex && next) {
422         os << "  " << current_vertex->graph_label() << " -> "
423            << next->graph_label() << " [ style = dotted constraint = false ]";
424       }
425       current_vertex = next;
426     } while (current_vertex != 0);
427   }
428 
429   os << endl << "}" << endl;
430 }
431 
432 void
reset()433 DirectedGraph::reset()
434 {
435   // Reset each vertex, releasing all arcs
436   __reset_dgvertex rv;
437   foreach(rv);
438   __reset_safeptr rptr;
439   foreach(rptr);
440 
441   // if everything went OK then empty out stack_ and targets_
442   stack_.clear();
443   targets_.clear();
444   first_to_compute_.reset();
445   func_names_.clear();
446 }
447 
448 
449 /// Apply a strategy to all vertices not yet computed (i.e. which do not have exit arcs)
450 void
apply(const SafePtr<Strategy> & strategy,const SafePtr<Tactic> & tactic)451 DirectedGraph::apply(const SafePtr<Strategy>& strategy,
452                      const SafePtr<Tactic>& tactic)
453 {
454   const SafePtr<DirectedGraph> this_ptr = SafePtr_from_this();
455   typedef vertices::const_iterator citer;
456   typedef vertices::iterator iter;
457   for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
458     const ver_ptr& vptr = vertex_ptr(*v);
459     if ((vptr)->num_exit_arcs() != 0 || (vptr)->precomputed() || !(vptr)->need_to_compute())
460       continue;
461 
462     SafePtr<RecurrenceRelation> rr0 = strategy->optimal_rr(this_ptr,(vptr),tactic);
463     if (rr0 == 0)
464       continue;
465 
466     // add children to the graph
467     SafePtr<DGVertex> target = rr0->rr_target();
468     const int num_children = rr0->num_children();
469     for(int c=0; c<num_children; c++) {
470       SafePtr<DGVertex> child = rr0->rr_child(c);
471       bool new_vertex = true;
472       SafePtr<DGVertex> dgchild = append_vertex(child);
473       if (dgchild != child) {
474 	child = dgchild;
475 	new_vertex = false;
476       }
477       SafePtr<DGArc> arc(new DGArcRel<RecurrenceRelation>(target,child,rr0));
478       target->add_exit_arc(arc);
479       if (new_vertex)
480         apply_to(child,strategy,tactic);
481     }
482   }
483 }
484 
485 /// Add vertex to graph and apply a strategy to vertex recursively
486 void
apply_to(const SafePtr<DGVertex> & vertex,const SafePtr<Strategy> & strategy,const SafePtr<Tactic> & tactic)487 DirectedGraph::apply_to(const SafePtr<DGVertex>& vertex,
488                         const SafePtr<Strategy>& strategy,
489                         const SafePtr<Tactic>& tactic)
490 {
491   bool not_yet_computed = !vertex->precomputed() && vertex->need_to_compute() && (vertex->num_exit_arcs() == 0);
492   if (!not_yet_computed)
493     return;
494   SafePtr<RecurrenceRelation> rr0 = strategy->optimal_rr(SafePtr_from_this(),vertex,tactic);
495   if (rr0 == 0)
496     return;
497 
498   SafePtr<DGVertex> target = rr0->rr_target();
499   const int num_children = rr0->num_children();
500   for(int c=0; c<num_children; c++) {
501     SafePtr<DGVertex> child = rr0->rr_child(c);
502     bool new_vertex = true;
503     SafePtr<DGVertex> dgchild = append_vertex(child);
504     if (dgchild != child) {
505       child = dgchild;
506       new_vertex = false;
507     }
508     SafePtr<DGArc> arc(new DGArcRel<RecurrenceRelation>(target,child,rr0));
509     try {
510       target->add_exit_arc(arc);
511     }
512     catch (...) {
513       std::cout << "failed to use RR " << rr0->label() << std::endl;
514       throw;
515     }
516     if (new_vertex)
517       apply_to(child,strategy,tactic);
518   }
519 }
520 
521 // Optimize out simple recurrence relations
522 void
optimize_rr_out(const SafePtr<CodeContext> & context)523 DirectedGraph::optimize_rr_out(const SafePtr<CodeContext>& context)
524 {
525   replace_rr_with_expr();
526   remove_trivial_arithmetics();
527   handle_trivial_nodes(context);
528   remove_disconnected_vertices();
529   find_subtrees();
530 }
531 
532 // Replace recurrence relations with expressions
533 void
replace_rr_with_expr()534 DirectedGraph::replace_rr_with_expr()
535 {
536   typedef vertices::const_iterator citer;
537   typedef vertices::iterator iter;
538   for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
539     const ver_ptr& vptr = vertex_ptr(*v);
540     if ((vptr)->num_exit_arcs()) {
541       SafePtr<DGArc> arc0 = *((vptr)->first_exit_arc());
542       SafePtr<DGArcRR> arc0_cast = dynamic_pointer_cast<DGArcRR,DGArc>(arc0);
543       if (arc0_cast == 0)
544         continue;
545       SafePtr<RecurrenceRelation> rr = arc0_cast->rr();
546 
547       // Optimize if the recurrence relation is simple and the target and
548       // children are of the same type
549       if (rr->is_simple() && rr->invariant_type()) {
550 
551 #if DEBUG || DEBUG_RESTRUCTURE
552 	std::cout << "replace_rr_with_expr: replacing " << rr->label() << endl;
553 	std::cout << "replace_rr_with_expr:      with " << rr->rr_expr()->description() << endl;
554 	std::cout << "replace_rr_with_expr: nchildren = " << rr->num_children() << endl;
555 	for(unsigned int c=0; c<rr->num_children(); ++c) {
556 	  std::cout << "replace_rr_with_expr: child " << c << " " << rr->rr_child(c)->description() << endl;
557 	}
558 #endif
559 
560         // Remove arcs connecting this vertex to children
561         (vptr)->del_exit_arcs();
562 
563         // and instead insert the numerical expression
564         SafePtr<RecurrenceRelation::ExprType> rr_expr = rr->rr_expr();
565         SafePtr<DGVertex> expr_vertex = static_pointer_cast<RecurrenceRelation::ExprType,DGVertex>(rr_expr);
566         expr_vertex = insert_expr_at((vptr),rr_expr);
567         SafePtr<DGArc> arc(new DGArcDirect((vptr),expr_vertex));
568         (vptr)->add_exit_arc(arc);
569 
570       }
571     }
572   }
573 }
574 
575 
576 //
577 // This function is very tricky at the moment. The operands have to be added before the operator
578 // such that operands are guaranteed to be on graph before the operator. This way ExprType::equiv
579 // can simply compare operand pointers and the operator type (very cheap operations compared
580 // to the fully recursive explicit comparison).
581 //
582 SafePtr<DGVertex>
insert_expr_at(const SafePtr<DGVertex> & where,const SafePtr<RecurrenceRelation::ExprType> & expr)583 DirectedGraph::insert_expr_at(const SafePtr<DGVertex>& where, const SafePtr<RecurrenceRelation::ExprType>& expr)
584 {
585 #if DEBUG
586   cout << "insert_expr_at: " << expr->description() << endl;
587 #endif
588 
589   typedef RecurrenceRelation::ExprType ExprType;
590   SafePtr<DGVertex> expr_vertex = static_pointer_cast<DGVertex,ExprType>(expr);
591 
592   // If the expression is already on then return it
593   if (expr->dg() == this)
594     return expr_vertex;
595 
596   SafePtr<DGVertex> left_oper = expr->left();
597   SafePtr<DGVertex> right_oper = expr->right();
598   bool new_left =  (left_oper->dg()  != this);       //
599   bool new_right = (right_oper->dg() != this);       //
600   bool need_to_clone = false;  // clone expression if (parts of) the expression was found on the graph
601 
602   // See if left operand is also an operator
603   const SafePtr<ExprType> left_cast = dynamic_pointer_cast<ExprType,DGVertex>(left_oper);
604   // if yes -- add it to the graph recursively
605   if (left_cast) {
606     left_oper = insert_expr_at(expr_vertex,left_cast);
607   }
608   // else add it directly
609   else {
610     left_oper = append_vertex(left_oper);
611   }
612 #if ONLY_CLONE_IF_DIFF
613   if (left_oper != expr->left()) {
614 #if DEBUG
615     std::cout << "insert_expr_at: append(left) != left" << std::endl;
616 #endif
617     need_to_clone = true;
618     new_left = false;
619   }
620 #else
621 #error "ONLY_CLONE_IF_DIFF must be true"
622 #endif
623 
624   // See if right operand is also an operator
625   SafePtr<ExprType> right_cast = dynamic_pointer_cast<ExprType,DGVertex>(right_oper);
626   // if yes -- add it to the graph recursively
627   if (right_cast) {
628     right_oper = insert_expr_at(expr_vertex,right_cast);
629   }
630   // else add it directly
631   else {
632     right_oper = append_vertex(right_oper);
633   }
634 #if ONLY_CLONE_IF_DIFF
635   if (right_oper != expr->right()) {
636 #if DEBUG
637     std::cout << "insert_expr_at: append(right) != right" << std::endl;
638 #endif
639     need_to_clone = true;
640     new_right = false;
641   }
642 #else
643 #error "ONLY_CLONE_IF_DIFF must be true"
644 #endif
645 
646   if (need_to_clone) {
647     SafePtr<ExprType> expr_new(new ExprType(expr,left_oper,right_oper));
648     expr_vertex = static_pointer_cast<DGVertex,ExprType>(expr_new);
649 #if DEBUG
650     int nchildren = expr->num_exit_arcs();
651     cout << "insert_expr_at: cloned AlgebraicOperator with " << expr->num_exit_arcs() << " children" << endl;
652     if (nchildren) {
653       cout << "Left:  " << expr->left()->description() << endl;
654       cout << "Right: " << expr->right()->description() << endl;
655     }
656 #endif
657   }
658 
659   SafePtr<DGVertex> dgexpr_vertex = expr_vertex;
660   if (new_left || new_right)
661     add_new_vertex(expr_vertex);
662   else {
663     const bool do_cse = registry()->do_cse();
664     if (do_cse) {
665 #if DEBUG
666       std::cout << "insert_expr_at: appending vertex " << expr_vertex->description() << std::endl;
667 #endif
668       dgexpr_vertex = append_vertex(expr_vertex);
669     }
670     else {
671       add_new_vertex(expr_vertex);
672     }
673     if (expr_vertex != dgexpr_vertex) {
674       if (new_left || new_right) {
675         cout << "Problem detected: AlgebraicOperator is found on the stack but one of its operands was new" << endl;
676         cout << expr_vertex->description() << endl;
677         cout << dgexpr_vertex->description() << endl;
678         throw std::runtime_error("DirectedGraph::insert_expr_at() -- vertex is not new but one of the operands is");
679       }
680     }
681     expr_vertex = dgexpr_vertex;
682   }
683   SafePtr<DGArc> left_arc(new DGArcDirect(expr_vertex,left_oper));
684   expr_vertex->add_exit_arc(left_arc);
685   SafePtr<DGArc> right_arc(new DGArcDirect(expr_vertex,right_oper));
686   expr_vertex->add_exit_arc(right_arc);
687 #if DEBUG
688   cout << "insert_expr_at: added arc between " << where->description() << " and " << expr_vertex->description() << endl;
689 #endif
690 
691   return expr_vertex;
692 }
693 
694 // Replace recurrence relations with expressions
695 void
remove_trivial_arithmetics()696 DirectedGraph::remove_trivial_arithmetics()
697 {
698   using libint2::prefactor::Scalar;
699   const SafePtr< CTimeEntity<double> > const_one_point_zero = Scalar(1.0);
700   const SafePtr< CTimeEntity<double> > const_zero_point_zero = Scalar(0.0);
701   typedef vertices::const_iterator citer;
702   typedef vertices::iterator iter;
703   for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
704     const ver_ptr& vptr = vertex_ptr(*v);
705     SafePtr< AlgebraicOperator<DGVertex> > oper_cast = dynamic_pointer_cast<AlgebraicOperator<DGVertex>,DGVertex>((vptr));
706     if (oper_cast) {
707 
708       //std::cout << "cast " << vptr->description() << " to " << oper_cast->description() << std::endl;
709       typedef DGVertex::ArcSetType::const_iterator aciter;
710       aciter a = oper_cast->first_exit_arc();
711       SafePtr<DGVertex> left = (*a)->dest();  ++a;
712       SafePtr<DGVertex> right = oper_cast->num_exit_arcs()>1 ? (*a)->dest() : left; // num_exit_arcs==1 is a corner case, e.g. 1*1
713 
714       using libint2::algebra::OperatorTypes;
715 
716       // 1.0 * x = x || 0.0 + x = x
717       if (left->num_entry_arcs() == 1 &&
718           ((oper_cast->type() == OperatorTypes::Times && left->equiv(const_one_point_zero)) ||
719            (oper_cast->type() == OperatorTypes::Plus && left->equiv(const_zero_point_zero))    )
720          ) {
721 #if DEBUG
722         const bool success = remove_vertex_at((vptr),right);
723         if (success)
724           cout << "Removed vertex " << (vptr)->description() << endl;
725 #else
726         remove_vertex_at((vptr),right);
727 #endif
728       }
729       // x * 1.0 = x || x + 0.0 = x
730       else
731       if (right->num_entry_arcs() == 1 &&
732           ((oper_cast->type() == OperatorTypes::Times && right->equiv(const_one_point_zero)) ||
733            (oper_cast->type() == OperatorTypes::Plus && right->equiv(const_zero_point_zero))    )
734          ) {
735 #if DEBUG
736         const bool success = remove_vertex_at((vptr),left);
737         if (success)
738           cout << "Removed vertex " << (vptr)->description() << endl;
739 #else
740         remove_vertex_at((vptr),left);
741 #endif
742       }
743 
744       // NOTE : more cases to come
745     }
746   }
747 }
748 
749 namespace {
stack_symbol(const SafePtr<CodeContext> & ctext,const DGVertex::Address & address,const DGVertex::Size & size,const std::string & low_rank,const std::string & veclen,const std::string & prefix)750   std::string stack_symbol(const SafePtr<CodeContext>& ctext, const DGVertex::Address& address, const DGVertex::Size& size,
751                            const std::string& low_rank, const std::string& veclen,
752                            const std::string& prefix)
753   {
754     ostringstream oss;
755     std::string stack_address = ctext->stack_address(address);
756     oss << prefix << "[((hsi*" << size << "+"
757         << stack_address << ")*" << low_rank << "+lsi)*"
758         << veclen << "]";
759     return oss.str();
760   }
761 
762   /// Returns a "vector" form of stack symbol, e.g. converts libint->stack[x] to libint->stack[x+vi]
to_vector_symbol(const SafePtr<DGVertex> & v)763   inline std::string to_vector_symbol(const SafePtr<DGVertex>& v)
764   {
765     std::string::size_type current_pos = 0;
766     std::string symb = v->symbol();
767     // replace repeatedly until the string is exhausted
768     while(current_pos != std::string::npos) {
769 
770       // find "[" first
771       const std::string left_braket("[");
772       std::string::size_type where = symb.find(left_braket,current_pos);
773       current_pos = where;
774       // if the prefix indicating a stack symbol found:
775       // 1) make sure vi doesn't appear between the brakets
776       // 2) replace "]" with "+vi]"
777       if (where != std::string::npos) {
778         const std::string right_braket("]");
779         std::string::size_type where = symb.find(right_braket,current_pos);
780         if (where == std::string::npos)
781           throw logic_error("to_vector_symbol() -- address is set but no right braket found");
782 
783         const std::string forbidden("vi");
784         std::string::size_type pos = symb.find(forbidden,current_pos);
785         if (pos == std::string::npos || pos > where) {
786           const std::string what_to_add("+vi");
787           symb.insert(where,what_to_add);
788           current_pos = where + 4;
789         }
790         else {
791           current_pos = where + 1;
792         }
793       }
794     } // end of while
795     return symb;
796   }
797 };
798 
799 //
800 // Handles "trivial" nodes. A node is trivial is it satisfies the following conditions:
801 // 0) not a target
802 // 1) has only one child
803 // 2) the exit arc is of a trivial type (DGArcDirect or IntegralSet_to_Integral applied to node of size 1)
804 //
805 // By "handling" I mean either removing the node from the graph or making a node refer to another node so that
806 // no code is generated for it.
807 //
808 void
handle_trivial_nodes(const SafePtr<CodeContext> & context)809 DirectedGraph::handle_trivial_nodes(const SafePtr<CodeContext>& context)
810 {
811   typedef vertices::const_iterator citer;
812   typedef vertices::iterator iter;
813   for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
814     const ver_ptr& vptr = vertex_ptr(*v);
815     // or if has more than 1 child
816     if ((vptr)->num_exit_arcs() != 1)
817       continue;
818     SafePtr<DGArc> arc = *((vptr)->first_exit_arc());
819 
820     // Is the exit arc DGArcRel<IntegralSet_to_Integrals> and (vptr)->size() == 1?
821     if ((vptr)->size() == 1) {
822       SafePtr<DGArcRR> arc_cast = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
823       if (arc_cast) {
824         SafePtr<RecurrenceRelation> rr = arc_cast->rr();
825         SafePtr<IntegralSet_to_Integrals_base> rr_cast = dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(rr);
826         if (rr_cast) {
827           SafePtr<DGVertex> child = arc->dest();
828 
829           //if (child->symbol_set() == false)
830           {
831             const std::string stack_name("stack");
832             const SafePtr<ImplicitDimensions>& dims = ImplicitDimensions::default_dims();
833             std::string low_rank = dims->low_label();
834             std::string veclen = dims->vecdim_label();
835 
836             if ((vptr)->address_set()) {
837               child->set_symbol(stack_symbol(context,(vptr)->address()+0,(vptr)->size(),low_rank,veclen,stack_name));
838             }
839             if ((vptr)->symbol_set()) {
840               child->set_symbol(stack_symbol(context,0,(vptr)->size(),low_rank,veclen,(vptr)->symbol()));
841             }
842           }
843 
844           vptr->refer_this_to(child);
845           vptr->reset_symbol();
846 
847         }
848       }
849     }
850 
851 
852     // Is the exit arc DGArcDirect?
853     {
854       SafePtr<DGArcDirect> arc_cast = dynamic_pointer_cast<DGArcDirect,DGArc>(arc);
855       if (arc_cast) {
856         // remove the vertex, if possible
857         // if this is a target -- cannot remove
858         if (!(vptr)->is_a_target())
859           remove_vertex_at((vptr),arc->dest());
860       }
861     }
862 
863       // NOTE : more cases to come
864   }
865 }
866 
867 
868 // If v1 and v2 are connected by DGArcDirect and all entry arcs to v1 are of the DGArcDirect type as well,
869 // this function will reattach all arcs extering v1 to v2 and remove v1 from the graph alltogether.
870 // return true if successful, false otherwise
871 bool
remove_vertex_at(const SafePtr<DGVertex> & v1,const SafePtr<DGVertex> & v2)872 DirectedGraph::remove_vertex_at(const SafePtr<DGVertex>& v1, const SafePtr<DGVertex>& v2)
873 {
874 #if DEBUG
875     cout << "remove_vertex_at: replacing " << v1->description() << " with " << v2->description() << endl;
876 #endif
877 
878   // Collect all entry arcs in a container
879   DGVertex::ArcSetType v1_entry;
880   typedef DGVertex::ArcSetType::iterator aiter;
881   typedef DGVertex::ArcSetType::const_iterator aciter;
882   const aciter abegin = v1->first_entry_arc();
883   const aciter aend = v1->plast_entry_arc();
884   // Verify that all entry arcs are DGArcDirect
885   for(aciter a=abegin; a!=aend; ++a) {
886     // See if this is a direct arc -- otherwise cannot do this
887     SafePtr<DGArc> arc = (*a);
888     SafePtr<DGArcDirect> arc_cast = dynamic_pointer_cast<DGArcDirect,DGArc>(arc);
889     if (arc_cast == 0)
890       return false;
891     v1_entry.push_back(*a);
892 #if DEBUG
893     std::cout << "remove_vertex_at: examined v1 entry arc: from " << (*a)->orig()->description() << " to " << (*a)->dest()->description() << std::endl;
894 #endif
895   }
896 
897   // Verify that v1 and v2 are connected by an arc and it is the only arc exiting v1
898   /*if (v1->num_exit_arcs() != 1 || v1->exit_arc(0)->dest() != v2)
899     return false;*/
900 
901   // See if this is a direct arc -- otherwise cannot do this
902   SafePtr<DGArc> arc = *(v1->first_exit_arc());
903   SafePtr<DGArcDirect> arc_cast = dynamic_pointer_cast<DGArcDirect,DGArc>(arc);
904   if (arc_cast == 0)
905     return false;
906 
907   //
908   // OK, now do work!
909   //
910 
911 #if DEBUG || DEBUG_RESTRUCTURE
912   std::cout << "remove_vertex_at: v1" << endl;  v1->print(std::cout);
913   std::cout << "remove_vertex_at: v2" << endl;  v2->print(std::cout);
914 #endif
915 
916   // Reconnect each of v1's entry arcs to v2
917   unsigned int c = 0;
918   for(aiter i=v1_entry.begin(); i != v1_entry.end(); ++i, ++c) {
919     SafePtr<DGVertex> parent = (*i)->orig();
920 #if DEBUG || DEBUG_RESTRUCTURE
921     cout << "remove_vertex_at: replacing arc " << c << " connecting " << parent->description() << " to " << (*i)->dest()->description() << endl;
922     cout << "remove_vertex_at: replacing arc " << c << " connecting " << parent << " to " << (*i)->dest() << endl;
923 #endif
924     SafePtr<DGArcDirect> new_arc(new DGArcDirect(parent,v2));
925 #if DEBUG || DEBUG_RESTRUCTURE
926     cout << "remove_vertex_at:      with arc " << " connecting " << parent->description() << " to " << v2->description() << endl;
927     cout << "remove_vertex_at:      with arc " << " connecting " << parent << " to " << v2 << endl;
928 #endif
929     parent->replace_exit_arc(*i,new_arc);
930 #if DEBUG || DEBUG_RESTRUCTURE
931     cout << "Replaced arcs: parent " << parent->description() << " now connected to " << new_arc->dest()->description() << endl;
932     cout << "                ptr = " << parent << endl;
933     const unsigned int nchildren = parent->num_exit_arcs();
934     cout << "               parent has " << nchildren << " children" << endl;
935     unsigned int c=0;
936     for(aciter a = parent->first_exit_arc(); a!=parent->plast_exit_arc(); ++a, ++c) {
937       cout << "               child " << c << " " << (*a)->dest()->description() << endl;
938       cout << "               child " << c << " ptr = " << (*a)->dest() << endl;
939     }
940 #endif
941   }
942 
943 #if DEBUG || DEBUG_RESTRUCTURE
944   std::cout << "remove_vertex_at: v1" << endl;  v1->print(std::cout);
945   std::cout << "remove_vertex_at: v2" << endl;  v2->print(std::cout);
946 #endif
947 
948   // and fully disconnect this vertex
949   v1->detach();
950 #if DEBUG || DEBUG_RESTRUCTURE
951   std::cout << "remove_vertex_at: detached " << v1->description() << endl;
952 #endif
953 
954   return true;
955 }
956 
957 void
remove_disconnected_vertices()958 DirectedGraph::remove_disconnected_vertices()
959 {
960   typedef vertices::const_iterator citer;
961   typedef vertices::iterator iter;
962   for(iter v=stack_.begin(); v!=stack_.end();) {
963     const ver_ptr& vptr = vertex_ptr(*v);
964     iter vnext = v; ++vnext; // note the next value of iterator before trying to erase
965     if ((vptr)->num_entry_arcs() == 0 && (vptr)->num_exit_arcs() == 0 && !(vptr)->is_a_target()) {
966 #if DEBUG
967       cout << "Trying to erase disconnected vertex " << (vptr)->description() << " num_vertices = " << num_vertices() << endl;
968 #endif
969       try { del_vertex(v); }
970       catch (CannotPerformOperation& v) {
971 #if DEBUG
972         cout << "But couldn't!!!" << endl;
973 #endif
974         throw v;
975       }
976     }
977     // update the iterator
978     v = vnext;
979   }
980 }
981 
982 // generate_code uses this helper function. It's declared in libint2 namespace because
983 // it's also used elsewhere
984 namespace libint2 {
985   std::string
declare_function(const SafePtr<CodeContext> & context,const SafePtr<ImplicitDimensions> & dims,const SafePtr<CodeSymbols> & args,const std::string & tlabel,const std::string & function_descr,std::ostream & decl)986   declare_function(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims,
987                    const SafePtr<CodeSymbols>& args, const std::string& tlabel, const std::string& function_descr,
988                    std::ostream& decl) {
989 
990     std::string function_name = label_to_funcname(function_descr);
991     function_name = context->label_to_name(function_name);
992 
993     decl << context->code_prefix();
994     std::string func_decl;
995     std::ostringstream oss;
996     oss << context->type_name<void>() << " "
997         << function_name << "(" << context->const_modifier()
998         << context->inteval_type_name(tlabel) << "* inteval";
999     const unsigned int nargs = args->n();
1000     if (nargs > 0) {
1001       // first argument is always the target, which is never a const
1002       oss << ", " << context->type_name<double*>() << " "
1003           << args->symbol(0);
1004       for(unsigned int a=1; a<nargs; a++) {
1005         oss << ", " << context->type_name<const double*>() << " "
1006             << args->symbol(a);
1007       }
1008     }
1009     if (!dims->high_is_static()) {
1010       oss << ", " << context->type_name<int>() << " "
1011           << dims->high()->id();
1012     }
1013     if (!dims->low_is_static()) {
1014       oss << ", " << context->type_name<int>() << " "
1015           <<dims->low()->id();
1016     }
1017     oss << ")";
1018     func_decl = oss.str();
1019 
1020     decl << func_decl << context->end_of_stat() << endl;
1021     decl << context->code_postfix();
1022 
1023     return func_decl;
1024   }
1025 }
1026 
1027 namespace {
1028   template <class Container>
1029   SafePtr<MemBlockSet>
to_memoryblks(Container & vertices)1030   to_memoryblks(Container& vertices) {
1031     SafePtr<MemBlockSet> result(new MemBlockSet);
1032     typedef typename Container::const_iterator citer;
1033     typedef typename Container::iterator iter;
1034     citer end(vertices.end());
1035     for(iter v=vertices.begin(); v!=end; ++v) {
1036       const DirectedGraph::ver_ptr& vptr = vertex_ptr(*v);
1037       result->push_back(MemBlock((vptr)->address(),(vptr)->size(),false,SafePtr<MemBlock>(),SafePtr<MemBlock>()));
1038     }
1039     return result;
1040   }
1041 };
1042 
//
// Generates the declaration (written to decl) and the definition (written to def) of the function
// that evaluates this graph.
//
// context -- code context that supplies all language-specific text (headers, delimiters, type names, ...)
// memman  -- memory manager forwarded to allocate_mem()
// dims    -- implicit dimensions; used for the function signature, memory allocation, symbol assignment,
//            and to size the blocks that are zeroed out for the prerequisites
// args    -- symbols of the function arguments (forwarded to declare_function() and print_def())
// label   -- label of the computation; it appears in the leading comment and the function name is derived from it
// decl    -- stream receiving the declaration
// def     -- stream receiving the definition
//
void
DirectedGraph::generate_code(const SafePtr<CodeContext>& context, const SafePtr<MemoryManager>& memman,
                             const SafePtr<ImplicitDimensions>& dims, const SafePtr<CodeSymbols>& args,
                             const std::string& label,
                             std::ostream& decl, std::ostream& def)
{
  // label of the current library task is needed to name the evaluator type in the signature
  LibraryTaskManager& taskmgr = LibraryTaskManager::Instance();
  const std::string tlabel = taskmgr.current().label();

  //
  // Generate function's declaration
  //
  decl << context->copyright();
  decl << context->std_header();
  std::string comment("This code computes "); comment += label; comment += "\n";
  if (context->comments_on())
    decl << context->comment(comment) << endl;

  // writes the declaration to decl and returns its text for reuse in the definition below
  const std::string func_decl = declare_function(context,dims,args,tlabel,label,decl);

  // are there prerequisite vertices that are not precomputed? Then may need to call precompute function
  const bool missing_prereqs = this->missing_prerequisites();
  std::string func_prereq_name;
  std::string func_prereq_decl;
  if (missing_prereqs) {
    // the declaration of the prerequisite function goes to a scratch stream, i.e. it is not added to decl;
    // func_prereq_decl is not used further in this function
    std::ostringstream oss;
    func_prereq_name = context->label_to_name(label_to_funcname(label)) + "_prereq";
    func_prereq_decl = declare_function(context,dims,args,tlabel,func_prereq_name,oss);
  }

  //
  // Generate function's definition
  //

  def << context->copyright();
  // include standard headers
  def << context->std_header();
  // include declarations for all function calls:
  // 1) update func_names_
  // 2) (optional) if will compute prerequisites add the name of the function that will compute them
  // 3) include their headers into the current definition file
  update_func_names();
  if (missing_prereqs)
    func_names_[func_prereq_name] = true;
  for(FuncNameContainer::const_iterator fn=func_names_.begin(); fn!=func_names_.end(); fn++) {
    string function_name = (*fn).first;
    def << "#include <"
        << context->label_to_name(context->cparams()->api_prefix() + function_name)
        << ".h>" << endl;
  }
  def << endl;

  def << context->code_prefix();
  def << func_decl << context->open_block() << endl;
  def << context->std_function_header();

  // allocate data and assign symbols
  context->reset();
  // if we vectorize by-line then all data is allocated on Libint's stack
  if (context->cparams()->vectorize_by_line())
    allocate_mem(memman,dims,0);
  // otherwise only arrays go on Libint's stack (scalars are handled by the compiler)
  else
    allocate_mem(memman,dims,1);
  assign_symbols(context,dims);

  // then ...
  if (missing_prereqs) { // need to compute prerequisites?

    // if profiling is on, start the next timer, after stopping the current timer (if possible)
    if (registry()->current_timer() >= 0) {
      const int current_timer = registry()->current_timer();
      def << context->macro_if("LIBINT2_CPLUSPLUS_STD >= 2011");
      if (current_timer > 0)
        def << "inteval->timers->stop(" << current_timer << ");" << std::endl;
      def << "inteval->timers->start(" << current_timer+1 << ");" << std::endl;
      def << context->macro_endif();
    }

    //
    // need to zero out space for all missing prerequisites -- prereq evaluator will accumulate into that space
    //

    // get prerequisites
    PrerequisitesExtractor pe;
    this->foreach(pe);
    // merge their blocks (likely into one -- allocate_mem should have taken care of that)
    SafePtr<MemBlockSet> targetblks = to_memoryblks(pe.vertices);
    merge(*targetblks);

    // zero each one out
    for(MemBlockSet::iterator
        b  = targetblks->begin();
        b != targetblks->end();
        ++b) {
      const size s = b->size();
      SafePtr<CTimeEntity<int> > bdim(new CTimeEntity<int>(s));

      // number of elements to zero = vector dimension times block size;
      // the vector dimension is a runtime or a compile-time entity depending on dims
      SafePtr<Entity> bvecdim;
      if (!dims->vecdim_is_static()) {
        SafePtr< RTimeEntity<EntityTypes::Int> > vecdim = dynamic_pointer_cast<RTimeEntity<EntityTypes::Int>,Entity>(dims->vecdim());
        bvecdim = vecdim * bdim;
      }
      else {
        SafePtr< CTimeEntity<int> > vecdim = dynamic_pointer_cast<CTimeEntity<int>,Entity>(dims->vecdim());
        bvecdim = vecdim * bdim;
      }
      // NOTE(review): no end-of-statement token is emitted after this call -- presumably
      // _libint2_static_api_bzero_short_ is a macro that expands to a complete statement; confirm
      def << "_libint2_static_api_bzero_short_(" << registry()->stack_name() << "+"
          << b->address() << "*" << dims->vecdim()->id() << "," << bvecdim->id() << ")" << endl;
    }

    // and call the prereq evaluator, once per contraction index c in a loop bounded by contrdepth
    SafePtr<Entity> zero(new CTimeEntity<int>(0));
    SafePtr<Entity> contr_depth(new RTimeEntity<EntityTypes::Int>("contrdepth"));
    std::string contr_index("c");
    SafePtr<ForLoop> contr_loop(new ForLoop(context,contr_index,contr_depth,zero));

    def << context->decldef("const int","contrdepth","inteval->contrdepth");
    def << contr_loop->open();
    def << func_prereq_name << "(inteval+c, "
        << registry()->stack_name()
        << ")" << context->end_of_stat() << endl;
    def << contr_loop->close() << endl;

    // if profiling is on, stop the timer started for the prerequisites and restart the current timer (if possible)
    if (registry()->current_timer() >= 0) {
      const int current_timer = registry()->current_timer();
      def << context->macro_if("LIBINT2_CPLUSPLUS_STD >= 2011");
      def << "inteval->timers->stop(" << current_timer+1 << ");" << std::endl;
      if (current_timer > 0)
        def << "inteval->timers->start(" << current_timer << ");" << std::endl;
      def << context->macro_endif();
    }

  }

  // if profiling=on and the current timer is outermost, start it
  // if this is not outermost timer, it has been started outside of this function
  if (registry()->current_timer() == 0) {
    def << context->macro_if("LIBINT2_CPLUSPLUS_STD >= 2011");
    def << "inteval->timers->start(0);" << std::endl;
    def << context->macro_endif();
  }

  // now print out the code for this graph
  print_def(context,def,dims,args);

  // if profiling=on and the current timer is outermost, stop it
  // if this is not outermost timer, it is managed outside of this function
  if (registry()->current_timer() == 0) {
    def << context->macro_if("LIBINT2_CPLUSPLUS_STD >= 2011");
    def << "inteval->timers->stop(0);" << std::endl;
    def << context->macro_endif();
  }

  def << context->close_block() << endl;
  def << context->code_postfix();
}
1201 
//
// Assigns stack addresses to the vertices of this graph using memman.
//
// memman            -- memory manager that hands out (alloc) and reclaims (free) addresses
// dims              -- not referenced in the body of this function
// min_size_to_alloc -- vertices whose size does not exceed this value are not put on the stack, unless
//                      they are targets or unrolled integral sets with non-precomputed members (see below)
//
// The order of the alloc() calls is: missing prerequisites, buffer for target accumulators (if needed),
// targets without symbol/address, then all remaining vertices in traversal order.
//
void DirectedGraph::allocate_mem(const SafePtr<MemoryManager>& memman,
const SafePtr<ImplicitDimensions>& dims,
unsigned int min_size_to_alloc)
{
  // NOTE does this belong here?
  // First, reset tag counters
  prepare_to_traverse();

  // Local helper that allocates a group of "target-like" vertices one after another.
  // If all_targets is true every vertex in the group is processed, otherwise only
  // those that have neither a symbol nor an address.
  struct TargetAllocator {
    typedef DirectedGraph::targets::const_iterator target_citer;
    typedef DirectedGraph::targets::iterator target_iter;
    typedef DirectedGraph::size sz;
    typedef DirectedGraph::address address;

    const DirectedGraph::targets& targets_;
    const SafePtr<MemoryManager>& memman_;
    bool all_targets_;
    sz size_;   // aggregate size of the vertices that allocate() will process

    TargetAllocator(const DirectedGraph::targets& t,
    const SafePtr<MemoryManager>& mm,
    bool all_targets) :
    targets_(t),
    memman_(mm),
    all_targets_(all_targets)
    {
      // compute the aggregate size of all targets
      target_citer end = targets_.end();
      size_ = 0;
      for(target_citer t=targets_.begin(); t!=end; ++t) {
        const ver_ptr& tptr = vertex_ptr(*t);
        if (all_targets_ ||
        (!tptr->symbol_set() &&
            !tptr->address_set()
        )
        ) {
          size_ += (tptr)->size();
        }
      }
    }

    sz size() const {return size_;}

    // allocates memory for each selected vertex, in container order
    void allocate() {
      for(target_citer v=targets_.begin(); v!=targets_.end(); ++v) {
        const ver_ptr& vptr = vertex_ptr(*v);
        if (all_targets_ ||
        (!vptr->symbol_set() &&
            !vptr->address_set()
        )
        ) {
          const MemoryManager::Address address = memman_->alloc(vptr->size());
          // if the vertex is just an alias, pass the address on
          if (vptr->refers_to_another())
            (*(vptr->first_exit_arc()))->dest()->set_address(address);
          else
            vptr->set_address(address);
        }
      }
    }
  };

  //
  // First, allocate all prerequisites that are not precomputed
  // since they will be computed before the evaluation of this graph
  // they will be targets of the previous computation and will be
  // at the beginning of the previous stack. Preallocate them here.
  //
  if (this->missing_prerequisites()) {
    PrerequisitesExtractor pe;
    this->foreach(pe);
    targets prereqs(pe.vertices.size());
    std::copy(pe.vertices.begin(), pe.vertices.end(), prereqs.begin());
    // prereqs will be put on the prereq DirectedGraph in the same order
    // so use the same allocation mechanism here as for targets
    const bool all_targets = true;
    TargetAllocator ta(prereqs, memman, all_targets);
    ta.allocate();
  }

  //
  // If need to accumulate targets, special events must happen here.
  //
  // NOTES on how to handle accumulation
  // 1) if all targets are unrolled then need to identify which integrals are part of target sets and use += instead of =
  // 2) if no targets are unrolled then allocate extra space for the target quartets and, after code has been generated,
  //    accumulate target sets into those
  //    EXCEPTION no new space for integral sets is needed if they are decontracted
  // 3) if some targets are not unrolled then still need the extra space. The targets which were not unrolled should be handled
  //    as usual, i.e. not accumulated -- accumulation happens at the end
  if (registry()->accumulate_targets()) {

    // need extra buffers for targets if some are not unrolled ( and not decontracted)
    //const bool need_copies_of_targets = nonunrolled_targets(targets_);
    const bool need_copies_of_targets = std::find_if(targets_.begin(),
                                                     targets_.end(),
                                                     [](SafePtr<DGVertex> i){
      return !DecontractedIntegralSet()(i) && !UnrolledIntegralSet()(i);
    }) != targets_.end();

    iregistry()->accumulate_targets_directly(!need_copies_of_targets);

    if (need_copies_of_targets) {

      // TargetAllocator is used here only to compute the aggregate size of all targets
      const bool all_targets = true;
      TargetAllocator ta(targets_, memman, all_targets);
      const size size_of_targets = ta.size();
      iregistry()->size_of_target_accum(size_of_targets);

      // allocate every target accumulator manually:
      // one contiguous buffer is requested and carved up in the order of targets_
      const address targets_buffer = memman->alloc(size_of_targets);
      address curr_ptr = targets_buffer;
      for(target_citer t=targets_.begin(); t!=targets_.end(); ++t) {
        const ver_ptr& tptr = vertex_ptr(*t);
        target_accums_.push_back(curr_ptr);
        curr_ptr += (tptr)->size();
      }

    } // need copies of targets
  } // need to accumulate targets

  // Second, MUST allocate space for all targets whose symbols are not set explicitly
  // If a symbol is set means the object is not on stack (e.g. if location of target
  // is passed as an argument to set-level function)
  // This code ensures that target quartets are persistent, i.e. never overwritten, and can be accumulated into
  {
    const bool all_targets = false; // only targets without symbols
    TargetAllocator ta(targets_, memman, all_targets);
    ta.allocate();
  }

  //
  // How memory management happens:
  // Go through the traversal order and at each step tag every child
  // Once a child receives same number of tags as the number of parents,
  // it can be deallocated
  //
  // NOTE(review): first_to_compute_ is dereferenced in the first iteration without a null check --
  // presumably the graph has been traversed and is non-empty at this point; confirm
  SafePtr<DGVertex> vertex = first_to_compute_;
  do {
    SafePtr<DGArcRR> arcrr;
    // memory only needs to be managed for some quantities:
    // this conditional decides whether this vertex is on the stack
    bool need_to_allocate = true;

    // If symbol is set then the object is not on stack
    need_to_allocate &=  not vertex->symbol_set();

    // if address is already set, no need to manage
    need_to_allocate &=  not vertex->address_set();

    // precomputed objects don't go on stack
    need_to_allocate &= not vertex->precomputed();

    // don't allocate on stack if smaller than or equal to min_size_to_alloc
    // two exceptions, however:
    // 1) it's a target
    // 2) it's an unrolled integral set whose members are not precomputed quantities
    //    typically integral sets of size 1 are precomputed and don't need to be on stack,
    //    however if they are not, the integral will need to be stored somewhere and the rule for
    //    assigning code symbols to members of unrolled integral sets requires the integral set
    //    to have an address assigned
    need_to_allocate &= ( vertex->size() > min_size_to_alloc ||
                          vertex->is_a_target() ||
                          (vertex->size() == vertex->num_exit_arcs() &&
                              ( (arcrr = dynamic_pointer_cast<DGArcRR,DGArc>(*(vertex->first_exit_arc()))) != 0 ?
                                  dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(arcrr->rr()) != 0 :
                                  false ) &&
                                  !(*(vertex->first_exit_arc()))->dest()->precomputed()
                          )
    );

    // if this is a member of unrolled integral set whose address or symbol is set, no need to allocate
    // (only the first entry arc is examined)
    if (need_to_allocate && vertex->num_entry_arcs() != 0) {
      arcrr = dynamic_pointer_cast<DGArcRR,DGArc>(*(vertex->first_entry_arc()));
      if (arcrr) {
        if (dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(arcrr->rr()) != 0) {
          if (arcrr->orig()->symbol_set() || arcrr->orig()->address_set())
            need_to_allocate = false;
        }
      }
    }

    if (need_to_allocate) {
      MemoryManager::Address addr = memman->alloc(vertex->size());
      vertex->set_address(addr);
#if DEBUG
      cout << "allocated " << vertex->description() << " at " << addr << " (size=" << vertex->size() << ")" << endl;
#endif

      typedef DGVertex::ArcSetType::const_iterator aciter;
      const aciter abegin = vertex->first_exit_arc();
      const aciter aend = vertex->plast_exit_arc();
      // Tag every child of this vertex and release the memory of each child whose tag count
      // has reached the number of its parents
      // (tag() presumably adds one tag and returns the updated count -- confirm in DGVertex)
      for(aciter a=abegin; a!=aend; ++a) {
        SafePtr<DGVertex> child = (*a)->dest();
        const unsigned int ntags = child->tag();
        // Do NOT deallocate if it's a target!
        if (ntags == child->num_entry_arcs() && child->address_set() && !child->is_a_target()) {
          memman->free(child->address());
        }
      }
    }
    // proceed to the vertex computed next; a null pointer ends the traversal
    vertex = vertex->postcalc();
  }while (vertex != 0);
}
1407 
//
// Assigns code symbols to the vertices of this graph. Four passes over stack_ are made:
// 1) vertices that have an address but no symbol receive a stack-based symbol;
// 2) members of integral sets unrolled with IntegralSet_to_Integrals receive symbols derived from the
//    address or the symbol of the set, and the set is made to refer to its first member;
// 3) remaining vertices other than operators (compile-time constants, precomputed quantities,
//    runtime quantities, integral sets of size 1) receive symbols;
// 4) everything still without a symbol is handed to assign_oper_symbol(), in reverse order.
//
// context -- code context used to compose stack symbols and to generate unique names
// dims    -- implicit dimensions that provide the labels of the low and vector dimensions
//
void
DirectedGraph::assign_symbols(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims)
{
  // NOTE: os and null_str are not used in this function
  std::ostringstream os;
  const std::string null_str("");
  const std::string stack_name = registry()->stack_name();
  // There used to be a compiler/library/bug on OS X? The above failed... Valgrind under Linux shows no memory problems...
  //const std::string stack_name("stack");

  // Generate the label for the rank of the low dimension
  std::string low_rank = dims->low_label();
  std::string veclen = dims->vecdim_label();

  // First, set symbols for all vertices which have address assigned
  typedef vertices::const_iterator citer;
  typedef vertices::iterator iter;
  for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
    const ver_ptr& vptr = vertex_ptr(*v);
    if (!(vptr)->symbol_set() && (vptr)->address_set()) {
      (vptr)->set_symbol(stack_symbol(context,(vptr)->address(),(vptr)->size(),low_rank,veclen,stack_name));
    }
  }

  // Second, find all nodes which were unrolled using IntegralSet_to_Integrals:
  // 1) such nodes do not need symbols generated since they never appear in the code explicitly
  // 2) children of such nodes have symbols that depend on the parent's address
  // upstream such nodes of size one were aliased ("referred") to their children -- just skip these
  for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
    const ver_ptr& vptr = vertex_ptr(*v);
    if ((vptr)->num_exit_arcs() == 0)
      continue;
    if (vptr->refers_to_another())
      continue;
    // only the first exit arc is inspected to decide whether the node was unrolled
    SafePtr<DGArc> arc = *((vptr)->first_exit_arc());
    SafePtr<DGArcRR> arc_rr = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
    if (arc_rr == 0)
      continue;
    SafePtr<RecurrenceRelation> rr = arc_rr->rr();
    SafePtr<IntegralSet_to_Integrals_base> iset_to_i = dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(rr);
    if (iset_to_i == 0) {
      continue;
    }
    else {
      typedef DGVertex::ArcSetType::const_iterator aciter;
      const aciter abegin = (vptr)->first_exit_arc();
      const aciter aend = (vptr)->plast_exit_arc();
      // c is the position of the child among the exit arcs; it serves as the offset of the child within the set
      unsigned int c = 0;
      // Assign a symbol to each child of the unrolled set
      for(aciter a=abegin; a!=aend; ++a, ++c) {
        SafePtr<DGVertex> child = (*a)->dest();
        // If the child is precomputed and its parent symbol is not set -- its symbol will be set as usual
        if (!child->precomputed() || (vptr)->symbol_set()) {

          // check if symbol is already set ... this indicates interference with the logic upstream, it should be set here?
          if (child->symbol_set()) {
            std::cout << "WARNING: symbol for " << child->description() << " (unrolled integral) already set" << std::endl;
            continue;
          }

          // the child is addressed relative to the parent's stack address if there is one,
          // otherwise relative to the parent's symbol
          if ((vptr)->address_set()) {
            child->set_symbol(stack_symbol(context,(vptr)->address()+c,(vptr)->size(),low_rank,veclen,stack_name));
          }
          else {
            child->set_symbol(stack_symbol(context,c,(vptr)->size(),low_rank,veclen,(vptr)->symbol()));
          }
        }
      }
      // the set itself becomes a reference to its first child and gives up its own symbol
      (vptr)->refer_this_to((*((vptr)->first_exit_arc()))->dest());
      (vptr)->reset_symbol();
    }
  }

  // then process all other symbols, EXCEPT operators
  for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
    const ver_ptr& vptr = vertex_ptr(*v);
#if DEBUG
    cout << "Trying to assign symbol to " << (vptr)->description() << endl;
#endif
    if ((vptr)->symbol_set()) {
#if DEBUG
      cout << "symbol already set to " << (vptr)->symbol() << endl;
#endif
      continue;
    }

    // test if the vertex is a static quantity, like a constant
    {
      typedef CTimeEntity<double> cdouble;
      SafePtr<cdouble> ptr_cast = dynamic_pointer_cast<cdouble,DGVertex>((vptr));
      if (ptr_cast) {
        (vptr)->set_symbol(ptr_cast->label());
        continue;
      }
    }

    // test if the vertex is precomputed runtime quantity, like a geometric parameter
    // such quantities are read from the evaluator object, indexed by vi
    if ((vptr)->precomputed()) {
      std::string symbol("inteval->");
      symbol += context->label_to_name((vptr)->label());
      symbol += "[vi]";
      (vptr)->set_symbol(symbol);
      continue;
    }

    // test if the vertex is other runtime quantity
    {
      // NOTE: despite its name this typedef denotes a runtime entity (RTimeEntity)
      typedef RTimeEntity<double> cdouble;
      SafePtr<cdouble> ptr_cast = dynamic_pointer_cast<cdouble,DGVertex>((vptr));
      if (ptr_cast) {
        (vptr)->set_symbol(ptr_cast->label());
        continue;
      }
    }

    // test if the vertex is some other data that was not allocated on stack
#if 1
    if (vptr->size() == 1) {
      { // basic integral
        typedef IntegralSet<IncableBFSet> intset;
        SafePtr<intset> ptr_cast = dynamic_pointer_cast<intset,DGVertex>((vptr));
        if (ptr_cast) {
          (vptr)->set_symbol(context->unique_name<EntityTypes::FP>());
          continue;
        }
      }
    }
#endif

  } // done with everything BUT operators

  // finally, process all operators (start with most recently added vertices since those are
  // much more likely to be on the bottom of the graph).
  typedef vertices::const_reverse_iterator criter;
  typedef vertices::reverse_iterator riter;
  for(riter v=stack_.rbegin(); v!=stack_.rend(); ++v) {
    ver_ptr& vptr = vertex_ptr(*v);
    if (vptr->symbol_set())
      continue;
#if DEBUG
    cout << "Trying to assign symbol to operator " << (vptr)->description() << endl;
#endif
    assign_oper_symbol(context,(vptr));
  }

}
1553 
1554 void
assign_oper_symbol(const SafePtr<CodeContext> & context,SafePtr<DGVertex> & vertex)1555 DirectedGraph::assign_oper_symbol(const SafePtr<CodeContext>& context, SafePtr<DGVertex>& vertex)
1556 {
1557   // do nothing if the vertex has a symbol or is not an operator
1558   if (vertex->symbol_set())
1559     return;
1560 
1561   {
1562     typedef AlgebraicOperator<DGVertex> oper;
1563     SafePtr<oper> ptr_cast = dynamic_pointer_cast<oper,DGVertex>(vertex);
1564     if (ptr_cast) {
1565       // is it in a subtree?
1566       const bool on_a_subtree = (vertex->subtree() != 0);
1567 
1568       // If no -- it will be an automatic variable
1569       if (!on_a_subtree)
1570         vertex->set_symbol(context->unique_name<EntityTypes::FP>());
1571       // else assign symbols to left and right arguments
1572       else {
1573 	typedef DGVertex::ArcSetType::const_iterator aciter;
1574 	aciter arc = ptr_cast->first_exit_arc();
1575         SafePtr<DGVertex> left = (*arc)->dest(); ++arc;
1576         SafePtr<DGVertex> right = (*arc)->dest();
1577         assign_oper_symbol(context,left);
1578         assign_oper_symbol(context,right);
1579 
1580         std::ostringstream oss;
1581         oss << "( " << left->symbol() << " ) "
1582             << ptr_cast->label()
1583             << " ( " << right->symbol() << " )";
1584         vertex->set_symbol(oss.str());
1585       }
1586     }
1587   }
1588 }
1589 
1590 
namespace {
  /// Counts how many times \c token occurs in \c str.
  /// The search resumes one character past the start of every match,
  /// hence overlapping occurrences are counted individually.
  unsigned int nfind(const std::string& str, const std::string& token)
  {
    unsigned int count = 0;
    for(std::string::size_type pos = str.find(token);
        pos != std::string::npos;
        pos = str.find(token,pos+1)) {
      ++count;
    }
    return count;
  }

#define DO_NOT_COUNT_DIV 1
  /// Estimates the number of FLOPs in \c expr by counting the binary operators
  /// that are surrounded by single spaces (" * ", " + ", " - ").
  /// Divisions (" / ") are included only if DO_NOT_COUNT_DIV is 0.
  unsigned int nflops(const std::string& expr)
  {
    static const std::string times_token(" * ");
    static const std::string plus_token(" + ");
    static const std::string minus_token(" - ");

    unsigned int total = nfind(expr,times_token);
    total += nfind(expr,plus_token);
    total += nfind(expr,minus_token);
#if !DO_NOT_COUNT_DIV
    static const std::string div_token(" / ");
    total += nfind(expr,div_token);
#endif
    return total;
  }
}
1628 
/// Writes to \c os the code that computes the targets of this graph.
/// In order, the generated text consists of:
///  - the declaration of the "stack" pointer,
///  - (only if LIBINT_ACCUM_INTS is set and targets are accumulated) code that zeroes out
///    the accumulation blocks when inteval->zero_out_targets is set,
///  - the opening of the hsi, lsi and vi loops,
///  - one statement (or RR function call) per vertex, visited in postcalc() order
///    starting from first_to_compute_,
///  - the closing of the loops,
///  - (if targets are accumulated indirectly) calls that add each target to its accumulator,
///  - definitions of hsi, lsi and vi as 0, and, if requested by the registry,
///    the assignments of inteval->targets[],
///  - a comment with the number of flops and, if requested, the update of inteval->nflops[0].
///
/// As a side effect every visited vertex is marked as scheduled.
///
/// @param context code context used to render declarations, statements and comments
/// @param os the output stream receiving the generated code
/// @param dims provides the high, low and vector dimensions used as loop bounds
/// @param args function arguments; if there is at least one, the first is used as the stack symbol
void
DirectedGraph::print_def(const SafePtr<CodeContext>& context, std::ostream& os,
                        const SafePtr<ImplicitDimensions>& dims,
                        const SafePtr<CodeSymbols>& args)
{
  // scratch stream reused for composing comments and expressions
  std::ostringstream oss;
  const std::string null_str("");
  // only referenced by code that is currently disabled with #if 0
  SafePtr<Entity > ctimeconst_zero(new CTimeEntity<int>(0));

  //
  // set stack ... if this function was given any arguments, the first is always stack
  //
  os << context->decldef(context->type_name<double* const>(),
                         "stack",
                         (args->n() >= 1) ? args->symbol(0) : registry()->stack_name());

  // direct accumulation requires both the graph registry and the internal registry to agree;
  // otherwise, if accumulation is requested at all, it is done indirectly (after the loops)
  const bool accumulate_targets_directly = registry()->accumulate_targets() && iregistry()->accumulate_targets_directly();
  const bool accumulate_targets_indirectly = registry()->accumulate_targets() && !accumulate_targets_directly;

  //
  // To optimize accumulation and setting blocks to zero, need to merge maximally the targets (the accumulation area is one block)
  //
  SafePtr<MemBlockSet> targetblks;
  if (registry()->accumulate_targets()) {
    if (accumulate_targets_directly) {
      targetblks = to_memoryblks(targets_);
      merge(*targetblks);
    }
    else {
      // indirect accumulation: a single block starting at address 0 of size size_of_target_accum()
      targetblks = SafePtr<MemBlockSet>(new MemBlockSet);
      targetblks->push_back(MemBlock(0,iregistry()->size_of_target_accum(),false,SafePtr<MemBlock>(),SafePtr<MemBlock>()));
    }
  }

  //
  // If accumulating integrals, check inteval's zero_out_targets. If set to 1 -- zero out accumulated targets
  //
#if LIBINT_ACCUM_INTS
  if (registry()->accumulate_targets()) {

    const bool vecdim_is_static = dims->vecdim_is_static();

    os << "if (inteval->zero_out_targets) {" << std::endl;

    typedef MemBlockSet::const_iterator citer;
    typedef MemBlockSet::iterator iter;
    const citer end = targetblks->end();
    for(iter b=targetblks->begin(); b!=end; ++b) {

      size s = b->size();
      SafePtr<CTimeEntity<int> > bdim(new CTimeEntity<int>(s));

      // length of the block to zero = vector dimension * block size;
      // the product is formed with the entity type that matches the kind of vecdim
      SafePtr<Entity> bvecdim;
      if (!vecdim_is_static) {
	SafePtr< RTimeEntity<EntityTypes::Int> > vecdim = dynamic_pointer_cast<RTimeEntity<EntityTypes::Int>,Entity>(dims->vecdim());
	bvecdim = vecdim * bdim;
      }
      else {
	SafePtr< CTimeEntity<int> > vecdim = dynamic_pointer_cast<CTimeEntity<int>,Entity>(dims->vecdim());
	bvecdim = vecdim * bdim;
      }

#if 0
      std::string loopvar("i");
      ForLoop loop(context,loopvar,bvecdim,ctimeconst_zero);
      os << loop.open();
      {
	ostringstream oss;
	oss << registry()->stack_name() << "[" << loopvar << "]";
	const std::string zero("0");
	os << context->assign(oss.str(),zero);
      }
      os << loop.close();
#endif
      // emit the zeroing call with arguments (stack + address*vecdim, vecdim*size)
      os << "_libint2_static_api_bzero_short_(" << registry()->stack_name() << "+"
	 << b->address() << "*" << dims->vecdim()->id() << "," << bvecdim->id() << ")" << endl;

    }

    os << "inteval->zero_out_targets = 0;" << std::endl << "}" << std::endl;
  }
#endif

  // open the loops over the high ("hsi") and low ("lsi") dimensions, both starting from 0
  std::string varname("hsi");
  SafePtr<ForLoop> hsi_loop(new ForLoop(context,varname,dims->high(),SafePtr<Entity>(new CTimeEntity<int>(0))));
  os << hsi_loop->open();

  varname = "lsi";
  SafePtr<ForLoop> lsi_loop(new ForLoop(context,varname,dims->low(),SafePtr<Entity>(new CTimeEntity<int>(0))));
  os << lsi_loop->open();

  // the vector loop is created outside of the body of the function if
  // 1) blockwise vectorization is requested
  // and
  // 2) this is a purely int-unit code, i.e. there are no explicit RRs on sets in the body
  // Otherwise, I create a dummy vector loop with the vector loop index set to 0
  const unsigned int max_vector_length = context->cparams()->max_vector_length();
  const bool vectorize = (max_vector_length != 1);
  const bool vectorize_by_line = context->cparams()->vectorize_by_line();
  const bool create_outer_vector_loop = !vectorize_by_line && !cannot_enclose_in_outer_vloop();
  varname = "vi";
  // outer vector loop
  SafePtr<ForLoop> outer_vloop;
  // vector loop for each code line
  SafePtr<ForLoop> line_vloop;
  if (create_outer_vector_loop) {
    SafePtr<ForLoop> tmp_vi_loop(new ForLoop(context,varname,dims->vecdim(),SafePtr<Entity>(new CTimeEntity<int>(0))));
    outer_vloop = tmp_vi_loop;
  }
  else {
    SafePtr<Entity> unit_dim(new CTimeEntity<int>(1));
    SafePtr<ForLoop> tmp_vi_loop(new ForLoop(context,varname,unit_dim,SafePtr<Entity>(new CTimeEntity<int>(0))));
    outer_vloop = tmp_vi_loop;
    SafePtr<ForLoop> tmp2_vi_loop(new ForLoop(context,varname,dims->vecdim(),SafePtr<Entity>(new CTimeEntity<int>(0))));
    line_vloop = tmp2_vi_loop;
    // note that both loops use same variable name -- standard C++ scoping rules allow it -- but the outer
    // loop will become a declaration of a constant variable
  }
  os << outer_vloop->open();

  //
  // generate code for vertices
  //
  unsigned int nflops_total = 0;
  // NOTE(review): the do-while below dereferences current_vertex before testing it,
  // so this presumably relies on first_to_compute_ being non-null -- confirm against callers
  SafePtr<DGVertex> current_vertex = first_to_compute_;
  do {

    // skip if already scheduled
    if (current_vertex->scheduled())
      goto next;

    // skip if this is a dummy vertex (i.e. refers to someone else)
    if (current_vertex->refers_to_another())
      goto next;

    // for every vertex that has a defined symbol, hence must be defined in code
    if (current_vertex->symbol_set()) {

      // Type declaration if this is an automatic variable
      // ids for automatic variables cannot have characters '[' or ']'
      const std::string& symbol = current_vertex->symbol();
      if (symbol.find('[') == std::string::npos) {
        os << context->declare(context->type_name<double> (),
                               current_vertex->symbol());
#if CHECK_SAFETY
        current_vertex->declared(true);
#endif
      }

      // print algebraic expression
      {
        typedef AlgebraicOperator<DGVertex> oper_type;
        SafePtr<oper_type> oper_ptr =
            dynamic_pointer_cast<oper_type, DGVertex> (current_vertex);
        if (oper_ptr) {

          // If this is an Integral in a target IntegralSet AND
          // can accumulate targets directly -- use '+=' instead of '='
          const bool accumulate_not_assign = accumulate_targets_directly
              && IntegralInTargetIntegralSet()(current_vertex);

          // the first two exit arcs of the operator point to its left and right arguments
          typedef DGVertex::ArcSetType::const_iterator aciter;
          aciter a = oper_ptr->first_exit_arc();
          auto left_arg = (*a)->dest();
          ++a;
          auto right_arg = (*a)->dest();

          if (context->comments_on()) {

            oss.str(null_str);
            oss << current_vertex->label() << (accumulate_not_assign ? " += "
                                                                     : " = ")
                << left_arg->label() << oper_ptr->label() << right_arg->label();
            os << context->comment(oss.str()) << endl;
          }

          // expression
#if DEBUG
          cout << "Generating code for " << current_vertex->description() << endl;
          cout << "              ptr = " << current_vertex << endl;
          cout << "         left_arg = " << left_arg->description() << endl;
          cout << "        right_arg = " << right_arg->description() << endl;
#endif

          // can we generate a composite instruction, like FMA?
          // - FMA: yes if:
          //     o current_vertex is a multiply
          //     o it has one parent and the parent is a +/- operator
          //     o parent's other argument has been computed already
          // (generate_fma can only become true if LIBINT_GENERATE_FMA is set)
          bool generate_fma = false;
          SafePtr<oper_type> parent_oper_ptr;
          SafePtr<DGVertex> fma_other_arg;
#if LIBINT_GENERATE_FMA
          {
            if (oper_ptr->type() == algebra::OperatorTypes::Times &&
                oper_ptr->num_entry_arcs() == 1) {
              parent_oper_ptr =
                          dynamic_pointer_cast<oper_type, DGVertex> ( (*(oper_ptr->first_entry_arc()))->orig() );
              if (parent_oper_ptr != 0) {
                auto parent_oper_type = parent_oper_ptr->type();
                if (parent_oper_type == algebra::OperatorTypes::Plus ||
                    parent_oper_type == algebra::OperatorTypes::Minus) {
                  auto arg1 = (*(parent_oper_ptr->first_exit_arc()))->dest();
                  auto arg2 = (*(++(parent_oper_ptr->first_exit_arc())))->dest();
                  auto other_arg = (arg1 == current_vertex) ? arg2 : arg1;
                  // the other argument is usable if it has no children or has been scheduled already
                  const bool other_ready = other_arg->num_exit_arcs() == 0 || other_arg->scheduled();

                  // can do fmadd if other_ready
                  // can do fmsub if other_ready and it's arg2
                  if (parent_oper_type == algebra::OperatorTypes::Plus)
                    generate_fma = other_ready;
                  else
                    generate_fma = other_ready && (current_vertex == arg1);
                  if (generate_fma) {
                    //std::cout << context->comment("CAN GENERATE FMA!!!") << endl;
                    fma_other_arg = other_arg;

#if DEBUG
                    std::cout << "Generating FMA:" << std::endl;
                    std::cout << "parent:\n";
                    parent_oper_ptr->print(std::cout);
                    std::cout << "current_vertex:\n";
                    current_vertex->print(std::cout);
                    std::cout << "other_vertex:\n";
                    other_arg->print(std::cout);
                    std::cout << "arg1:\n";
                    arg1->print(std::cout);
                    std::cout << "arg2:\n";
                    arg2->print(std::cout);
                    std::cout << "child1:\n";
                    (*(parent_oper_ptr->first_exit_arc()))->dest()->print(std::cout);
                    std::cout << "child2:\n";
                    (*(++(parent_oper_ptr->first_exit_arc())))->dest()->print(std::cout);
#endif


                    // may need to declare the variable for the parent
                    if (parent_oper_ptr->symbol_set()) {

                      // Type declaration if this is an automatic variable
                      // ids for automatic variables cannot have characters '[' or ']'
                      const std::string& symbol = parent_oper_ptr->symbol();
                      if (symbol.find('[') == std::string::npos) {
                        os << context->declare(context->type_name<double> (),
                                               parent_oper_ptr->symbol());
#if CHECK_SAFETY
                        parent_oper_ptr->declared(true);
#endif
                      }
                    }

                  }
                }
              }
            }
          }
#endif // LIBINT_GENERATE_FMA=1

          // convert symbols to their vector form if needed
          std::string curr_symbol = current_vertex->symbol();
          std::string left_symbol = left_arg->symbol();
          std::string right_symbol = right_arg->symbol();
          std::string parent_symbol = generate_fma ? parent_oper_ptr->symbol() : "";
          std::string fma_other_arg_symbol = generate_fma ? fma_other_arg->symbol() : "";
          if (vectorize) {
            curr_symbol = to_vector_symbol(current_vertex);
            left_symbol = to_vector_symbol(left_arg);
            right_symbol = to_vector_symbol(right_arg);
            if (generate_fma) {
              parent_symbol = to_vector_symbol(parent_oper_ptr);
              fma_other_arg_symbol = to_vector_symbol(fma_other_arg);
            }
          }
#if CHECK_SAFETY && 0
          bool left_not_declared = left_arg->need_to_compute() && !left_arg->declared();
          bool right_not_declared = right_arg->need_to_compute() && !right_arg->declared();
          if (left_not_declared || right_not_declared) {
            std::cout << "Current vertex:" << endl; current_vertex->print(std::cout);
            std::cout << "left arg      :" << endl; left_arg->print(std::cout);
            std::cout << "right arg     :" << endl; right_arg->print(std::cout);
            if (left_not_declared) throw ProgrammingError("DirectedGraph::print_def() -- left_arg not declared");
            if (right_not_declared) throw ProgrammingError("DirectedGraph::print_def() -- right_arg not declared");
          }
#endif

          if (vectorize_by_line)
            os << line_vloop->open();
          // the statement that does the work
          // NOTE(review): with FMA the statement targets the parent's symbol, yet
          // accumulate_not_assign was evaluated for current_vertex -- confirm this is intended
          {
            if (accumulate_not_assign) {

              if (generate_fma) {
                os << context->accumulate_ternary_expr(parent_symbol,
                                                       left_symbol,
                                                       oper_ptr->label(),
                                                       right_symbol,
                                                       parent_oper_ptr->label(),
                                                       fma_other_arg_symbol
                                                       );
                nflops_total += 2; // extra flop due to accumulation + extra flop due to FMA
              }
              else {
                os << context->accumulate_binary_expr(curr_symbol, left_symbol,
                                                      oper_ptr->label(),
                                                      right_symbol);
                nflops_total += 1; // extra flop due to accumulation
              }

            } else { // assign, not accumulate
              if (generate_fma) {
                os << context->assign_ternary_expr(parent_symbol,
                                                   left_symbol,
                                                   oper_ptr->label(),
                                                   right_symbol,
                                                   parent_oper_ptr->label(),
                                                   fma_other_arg_symbol
                );
                nflops_total += 1; // extra flop due to FMA
              }
              else {
                os
                << context->assign_binary_expr(curr_symbol, left_symbol,
                                               oper_ptr->label(),
                                               right_symbol);
              }
            }

            // 1 flop for the operator itself plus the operators embedded in the argument symbols
            nflops_total += (1 + nflops(left_symbol) + nflops(right_symbol));
          }
          if (vectorize_by_line)
            os << line_vloop->close();

          // if produced FMA, do not forget to mark the parent scheduled
          if (generate_fma) {
            parent_oper_ptr->schedule();
          }

          goto next;
        }
      }

      // print simple assignment statement
      if (current_vertex->num_exit_arcs() == 1) {
        typedef DGArcDirect arc_type;
        SafePtr<arc_type>
            arc_ptr =
                dynamic_pointer_cast<arc_type, DGArc> (
                                                       *(current_vertex->first_exit_arc()));
        if (arc_ptr) {

#if CHECK_SAFETY
          current_vertex->declared(true);

          SafePtr<DGVertex> rhs_arg = arc_ptr->dest();
          if (!rhs_arg->declared() && rhs_arg->need_to_compute()) {
            std::cout << "Current vertex:" << endl; current_vertex->print(std::cout);
            std::cout << "rhs_arg      :" << endl; rhs_arg->print(std::cout);
            throw ProgrammingError("DirectedGraph::print_def() -- rhs_arg not declared");
          }
#endif

          // If this is an Integral in a target IntegralSet AND
          // can accumulate targets directly -- use '+=' instead of '='
          const bool accumulate_not_assign = accumulate_targets_directly
              && IntegralInTargetIntegralSet()(current_vertex);

          if (context->comments_on()) {
            oss.str(null_str);
            oss << current_vertex->label() << (accumulate_not_assign ? " += "
                                                                     : " = ")
                << arc_ptr->dest()->label();
            os << context->comment(oss.str()) << endl;
          }

          // convert symbols to their vector form if needed
          std::string curr_symbol = current_vertex->symbol();
          std::string rhs_symbol = arc_ptr->dest()->symbol();
          if (vectorize) {
            curr_symbol = to_vector_symbol(current_vertex);
            rhs_symbol = to_vector_symbol(arc_ptr->dest());
          }

          if (vectorize_by_line)
            os << line_vloop->open();
          if (accumulate_not_assign) {
            os << context->accumulate(curr_symbol, rhs_symbol);
            nflops_total += nflops(rhs_symbol) + 1; // +1 due to +=
          } else {
            os << context->assign(curr_symbol, rhs_symbol);
            nflops_total += nflops(rhs_symbol);
          }
          if (vectorize_by_line)
            os << line_vloop->close();

          goto next;
        }
      }

      // print out a recurrence relation
      if (current_vertex->num_exit_arcs() != 0) {
        // printing a recurrence relation
        //std::cout << "DirectedGraph::print_def(): a RR making " << current_vertex->description() << std::endl;
        typedef DGArcRR arc_type;
        SafePtr<arc_type>
            arc_ptr =
                dynamic_pointer_cast<arc_type, DGArc> (
                                                       *(current_vertex->first_exit_arc()));
        if (arc_ptr) {
          SafePtr<RecurrenceRelation> rr = arc_ptr->rr();
          os << rr->spfunction_call(context, dims);
          nflops_total += rr->nflops();

          goto next;
        }
      }

#if 0
      {
        current_vertex->print(std::cout);
        throw std::runtime_error(
                                 "DirectedGraph::print_def() -- cannot handle this vertex yet");
      }
#endif
    }

    // every visited vertex is marked scheduled, whether or not code was emitted for it
    next:
    current_vertex->schedule();
    current_vertex = current_vertex->postcalc();

  } while (current_vertex != 0);

  os << outer_vloop->close();
  os << lsi_loop->close();
  os << hsi_loop->close();

  //
  // Accumulate targets
  //
  if (accumulate_targets_indirectly) {
    os << context->comment("Accumulate target integral sets") << std::endl;
    //const std::string& stack_name = registry()->stack_name();
    const bool vecdim_is_static = dims->vecdim_is_static();
#if 0
    unsigned int vecdim_rank;
    if (vecdim_is_static) {
      SafePtr<CTimeEntity<int> > cptr = dynamic_pointer_cast<CTimeEntity<int> ,Entity >(dims->vecdim());
      vecdim_rank = cptr->value();
    }
    const std::string times_vecdim("*" + dims->vecdim_label());
#endif

    // Loop over targets
    unsigned int curr_target = 0;
    for(target_iter t=targets_.begin(); t!=targets_.end(); ++t, ++curr_target) {
      const ver_ptr& tptr = vertex_ptr(*t);

      size s = (tptr)->size();
      SafePtr<CTimeEntity<int> > bdim(new CTimeEntity<int>(s));

      // number of elements to accumulate = vector dimension * target size
      SafePtr<Entity> bvecdim;
      if (!vecdim_is_static) {
	SafePtr< RTimeEntity<EntityTypes::Int> > vecdim = dynamic_pointer_cast<RTimeEntity<EntityTypes::Int>,Entity>(dims->vecdim());
	bvecdim = vecdim * bdim;
      }
      else {
	SafePtr< CTimeEntity<int> > vecdim = dynamic_pointer_cast<CTimeEntity<int>,Entity>(dims->vecdim());
	bvecdim = vecdim * bdim;
      }

      // For now write an explicit loop for each target. In the future should:
      // 1) check if all computed targets are adjacent (accumulated targets are adjacent by the virtue of the allocation mechanism)
      // 2) check the sizes and insert optimized calls, if possible

      // form a single loop over integrals and vector dimension
      // NOTE: single loop suffices because if outer/inner strides are not 1 this block of code should not be executed

#if 0
      SafePtr<Entity> loopmax;
      if (vecdim_is_static) {
	const int loopmax_value = s*vecdim_rank;
	loopmax = SafePtr<Entity>(new CTimeEntity<int>(loopmax_value));
      }
      else {
	ostringstream oss;  oss << s << times_vecdim;
	loopmax = SafePtr<Entity>(new RTimeEntity<EntityTypes::Int>(oss.str()));
      }

      std::string loopvar("i");
      ForLoop loop(context,loopvar,loopmax,ctimeconst_zero);
      os << loop.open();
      std::string acctarget;
      {
	ostringstream oss;
	oss << stack_name << "[" << target_accums_[curr_target] << times_vecdim << "+" << loopvar << "]";
	acctarget = oss.str();
      }
      std::string target;
      {
	ostringstream oss;
	oss << stack_name << "[" << (tptr)->address() << times_vecdim << "+" << loopvar << "]";
	target = oss.str();
      }
      os << context->accumulate(acctarget,target);
      os << loop.close();
#endif
      // emit the accumulation call with arguments
      // (stack + accumulator address*vecdim, stack + target address*vecdim, vecdim*size)
      os << "_libint2_static_api_inc1_short_("
	 << registry()->stack_name() << "+" << target_accums_[curr_target] << "*" << dims->vecdim()->id() << ","
	 << registry()->stack_name() << "+" << (tptr)->address() << "*" << dims->vecdim()->id() << ","
	 << bvecdim->id() << ")" << endl;

      // one flop per accumulated element of the target
      nflops_total += s;
    }
  }

  // Outside of loops stack symbols don't make sense, so we must define loop variables hsi, lsi, and vi to 0
  os << context->decldef(context->type_name<const int>(), "hsi", "0");
  os << context->decldef(context->type_name<const int>(), "lsi", "0");
  os << context->decldef(context->type_name<const int>(), "vi", "0");

  //
  // Now pass back all targets through the inteval object, if needed.
  //
  if (registry()->return_targets()) {
    unsigned int curr_target = 0;
    for(target_iter t=targets_.begin(); t!=targets_.end(); ++t, ++curr_target) {
      const ver_ptr& tptr = vertex_ptr(*t);
      // with indirect accumulation the returned pointer refers to the accumulator of the target,
      // otherwise to the target itself
      const std::string& symbol = (accumulate_targets_indirectly
				   //                                                                    is this correct?         ???
				   ? stack_symbol(context,target_accums_[curr_target],(tptr)->size(),dims->low_label(),dims->vecdim_label(),registry()->stack_name())
				   : (tptr)->symbol());
      os << "inteval->targets[" << curr_target << "] = "
	 << context->value_to_pointer(symbol) << context->end_of_stat() << endl;
    }
  }

  // Print out the number of flops
  oss.str(null_str);
  oss << "Number of flops = " << nflops_total;
  os << context->comment(oss.str()) << endl;

  // if flop counting is on, the generated code adds nflops_total * high * low * vecdim to inteval->nflops[0]
  if (context->cparams()->count_flops()) {
    oss.str(null_str);
    oss << nflops_total << " * " << dims->high_label() << " * "
        << dims->low_label() << " * "
        << dims->vecdim_label();
    os << context->assign_binary_expr("inteval->nflops[0]","inteval->nflops[0]","+",oss.str());
  }

}
2178 
2179 void
update_func_names()2180 DirectedGraph::update_func_names()
2181 {
2182   // Loop over all vertices
2183   typedef vertices::const_iterator citer;
2184   typedef vertices::iterator iter;
2185   for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
2186     const ver_ptr& vptr = vertex_ptr(*v);
2187     // for every vertex with children
2188     if ((vptr)->num_exit_arcs() > 0) {
2189       // if it must be computed using a RR
2190       SafePtr<DGArc> arc = *((vptr)->first_exit_arc());
2191       SafePtr<DGArcRR> arcrr = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
2192       if (arcrr != 0) {
2193         SafePtr<RecurrenceRelation> rr = arcrr->rr();
2194         // and the RR is complex (i.e. likely to result in a function call)
2195         if (!rr->is_simple()) {
2196           // add it to the RRStack
2197           func_names_[rr->label()] = true;
2198         }
2199       }
2200     }
2201   }
2202 }
2203 
2204 bool
cannot_enclose_in_outer_vloop() const2205 DirectedGraph::cannot_enclose_in_outer_vloop() const
2206 {
2207   SafePtr<DGVertex> current_vertex = first_to_compute_;
2208   do {
2209     const int nchildren = current_vertex->num_exit_arcs();
2210     if (nchildren > 0) {
2211       arc_ptr aptr = *(current_vertex->first_exit_arc());
2212       SafePtr<DGArcRR> aptr_cast = dynamic_pointer_cast<DGArcRR,arc>(aptr);
2213       // if this is a RR
2214       if (aptr_cast != 0) {
2215         // and a non-trivial one
2216         SafePtr<RecurrenceRelation> rr = aptr_cast->rr();
2217         if (!rr->is_simple()) {
2218           // if this is an Uncontract_Integral call, return false
2219           SafePtr<Uncontract_Integral_base> rr_ucb_ptr = dynamic_pointer_cast<Uncontract_Integral_base,RecurrenceRelation>(rr);
2220           if (rr_ucb_ptr)
2221             return false;
2222           else
2223             return true;
2224         }
2225       }
2226     }
2227     current_vertex = current_vertex->postcalc();
2228   } while (current_vertex != 0);
2229 
2230   return false;
2231 }
2232 
2233 void
find_subtrees()2234 DirectedGraph::find_subtrees()
2235 {
2236   // need to condense expressions?
2237   if (!registry()->condense_expr())
2238     return;
2239 
2240   // Find subtrees by starting from the targets and moving down ...
2241   typedef vertices::const_iterator citer;
2242   typedef vertices::iterator iter;
2243   for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
2244     const ver_ptr& vptr = vertex_ptr(*v);
2245     if ((vptr)->is_a_target() && (vptr)->num_entry_arcs() == 0) {
2246       find_subtrees_from(vptr);
2247     }
2248   }
2249 }
2250 
2251 void
find_subtrees_from(const SafePtr<DGVertex> & v)2252 DirectedGraph::find_subtrees_from(const SafePtr<DGVertex>& v)
2253 {
2254   // is not on a subtree already
2255   if (!v->subtree()) {
2256 
2257     bool useless_subtree = false;
2258 
2259     //
2260     // Subtrees are useless in the following cases:
2261     // 1) root is computed via RRs, not explicitly
2262     //
2263     {
2264       if (v->num_exit_arcs() > 0) {
2265         SafePtr<DGArc> arc = *(v->first_exit_arc());
2266         SafePtr<DGArcRR> arc_rr = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
2267         if (arc_rr)
2268           useless_subtree = true;
2269       }
2270     }
2271 
2272     // create subtree
2273     if (!useless_subtree) {
2274       SafePtr<DRTree> stree = DRTree::CreateRootedAt(v);
2275 
2276       // Remove all trivial subtrees
2277       if (stree) {
2278 #if DISABLE_SUBTREES
2279         if (stree->nvertices() >= 0) {
2280           stree->detach();
2281         }
2282 #else
2283         if (stree->nvertices() < 3) {
2284           stree->detach();
2285         }
2286 #endif
2287       }
2288     }
2289 
2290     // move on to children
2291     typedef DGVertex::ArcSetType::const_iterator aciter;
2292     const aciter abegin = v->first_exit_arc();
2293     const aciter aend = v->plast_exit_arc();
2294     for(aciter a=abegin; a!=aend; ++a) {
2295       find_subtrees_from((*a)->dest());
2296     }
2297   }
2298 }
2299 
2300 namespace {
2301   struct PrerequisiteNotComputed {
operator ()__anonad020a590a11::PrerequisiteNotComputed2302       bool operator()(const DirectedGraph::vertices::value_type& v) {
2303         return !v.second->precomputed() && v.second->num_exit_arcs() == 0;
2304       }
2305   };
2306 }
2307 
2308 /// return true if there are vertices with 0 children but not pre-computed
2309 bool
missing_prerequisites() const2310 DirectedGraph::missing_prerequisites() const {
2311   bool missing_prereqs = false;
2312   if (!this->registry()->ignore_missing_prereqs()) {
2313 #if 0
2314   missing_prereqs =
2315       find_if(this->stack_.begin(), this->stack_.end(), [](const vertices::value_type& v) {
2316                 return v.second->precomputed() == false && v.second->num_exit_arcs() == 0;
2317              }) != this->stack_.end();
2318 #else
2319   PrerequisiteNotComputed pred;
2320   missing_prereqs =
2321       std::find_if(stack_.begin(), stack_.end(), pred) != stack_.end();
2322 #endif
2323   }
2324   return missing_prereqs;
2325 }
2326 
2327 ////
2328 
2329 namespace libint2 {
2330 
2331   namespace {
2332 #if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
2333     typedef DirectedGraph::targets::value_type value_type;
2334     struct __NotUnrolledIntegralSet : public std::unary_function<const value_type&,bool> {
operator ()libint2::__anonad020a590b11::__NotUnrolledIntegralSet2335       bool operator()(const value_type& v) {
2336         return NotUnrolledIntegralSet()(v);
2337       }
2338     };
2339 #endif
2340   };
2341 
2342   bool
nonunrolled_targets(const DirectedGraph::targets & targets)2343   nonunrolled_targets(const DirectedGraph::targets& targets) {
2344     typedef DirectedGraph::target_citer citer;
2345     citer end = targets.end();
2346 #if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
2347     if (end != find_if(targets.begin(),end,__NotUnrolledIntegralSet()))
2348 #else
2349     if (end != find_if(targets.begin(),end,NotUnrolledIntegralSet()))
2350 #endif
2351       return true;
2352     else
2353       return false;
2354   }
2355 
2356   void
extract_symbols(const SafePtr<DirectedGraph> & dg)2357   extract_symbols(const SafePtr<DirectedGraph>& dg)
2358   {
2359     LibraryTaskManager& taskmgr = LibraryTaskManager::Instance();
2360     // symbol extractor
2361     {
2362       SafePtr<ExtractExternSymbols> extractor(new ExtractExternSymbols);
2363       dg->foreach(*extractor);
2364       const ExtractExternSymbols::Symbols& symbols = extractor->symbols();
2365       // pass on to the symbol maintainer of the current task
2366       taskmgr.current().symbols()->add(symbols);
2367 #if DEBUG
2368       // print out the symbols
2369       std::cout << "Recovered symbols from DirectedGraph for " << dg << std::endl;
2370       typedef ExtractExternSymbols::Symbols::const_iterator citer;
2371       citer end = symbols.end();
2372       for(citer t=symbols.begin(); t!=end; ++t)
2373 	std::cout << *t << std::endl;
2374 #endif
2375     }
2376     // RR extractor
2377     {
2378       SafePtr<ExtractRR> extractor(new ExtractRR);
2379       dg->foreach(*extractor);
2380       const ExtractRR::RRList& rrlist = extractor->rrlist();
2381       // pass on to the symbol maintainer of the current task
2382       taskmgr.current().symbols()->add(rrlist);
2383     }
2384   }
2385 
2386   //////////////////////
2387   void
operator ()(const SafePtr<DGVertex> & v)2388   PrerequisitesExtractor::operator()(const SafePtr<DGVertex>& v) {
2389 #if DEBUG
2390     std::cout << "PrerequisitesExtractor: considering " << v->description() << std::endl;
2391     v->print(std::cout);
2392 #endif
2393     if (!v->precomputed() &&
2394         v->num_exit_arcs() == 0) {
2395 #if DEBUG
2396       std::cout << "PrerequisitesExtractor: found candidate " << v->description() << std::endl;
2397 #endif
2398 
2399 #define EXTRACTINTEGRALSETS 1
2400 #if EXTRACTINTEGRALSETS
2401       // if this is an integral that was a member of a shell set, add the parent set to the prerequisite level
2402       bool member_of_shellset = false;
2403       SafePtr<DGVertex> parent_shellset;
2404       typedef DGVertex::ArcSetType::const_iterator citer;
2405       for(citer i=v->entry_arcs().begin(); i != v->entry_arcs().end() && !member_of_shellset; ++i) {
2406         const SafePtr<DGArc>& arc = *i;
2407         SafePtr<DGArcRR> arc_cast = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
2408         if (arc_cast) {
2409           SafePtr<RecurrenceRelation> rr = arc_cast->rr();
2410           SafePtr<IntegralSet_to_Integrals_base> rr_cast = dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(rr);
2411           if (rr_cast) {
2412             member_of_shellset = true;
2413             parent_shellset = rr->rr_target();
2414           }
2415         }
2416       }
2417       if (member_of_shellset) {
2418 #if DEBUG
2419         std::cout << "PrerequisitesExtractor: " << v->description() << " is a member of a shell set, will add that instead"<< std::endl;
2420 #endif
2421         if ( vertices.end() == find(vertices.begin(), vertices.end(), parent_shellset) ) {
2422           vertices.push_back(parent_shellset);
2423 #if DEBUG
2424           std::cout << "PrerequisitesExtractor: extracted " << parent_shellset->description() << std::endl;
2425 #endif
2426         }
2427         else {
2428 #if DEBUG
2429           std::cout << "PrerequisitesExtractor: candidate's parent already extracted" << std::endl;
2430 #endif
2431         }
2432       }
2433       else {
2434         vertices.push_back(v);
2435 #if DEBUG
2436         std::cout << "PrerequisitesExtractor: extracted " << v->description() << std::endl;
2437 #endif
2438       }
2439 
2440 #else
2441 
2442       vertices.push_back(v);
2443 #if DEBUG
2444       std::cout << "PrerequisitesExtractor: extracted " << v->description() << std::endl;
2445 #endif
2446 
2447 #endif
2448     }
2449   }
2450   //////////////////////
2451   void
operator ()(const SafePtr<DGVertex> & v)2452   VertexPrinter::operator()(const SafePtr<DGVertex>& v) {
2453     os << "VertexPrinter: " << v->description() << std::endl;
2454 #if DEBUG
2455     v->print(os);
2456 #endif
2457   }
2458 
2459 };
2460