1 /*
2 * Copyright (C) 2004-2021 Edward F. Valeev
3 *
4 * This file is part of Libint.
5 *
6 * Libint is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * Libint is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with Libint. If not, see <http://www.gnu.org/licenses/>.
18 *
19 */
20
21 #include <cstdio>
22 #include <functional>
23 #include <utility>
24 #include <fstream>
25 #include <dg.h>
26 #include <rr.h>
27 #include <strategy.h>
28 #include <prefactors.h>
29 #include <codeblock.h>
30 #include <default_params.h>
31 #include <graph_registry.h>
32 #include <global_macros.h>
33 #include <extract.h>
34 #include <algebra.h>
35 #include <task.h>
36 #include <context.h>
37 #include <intset_to_ints.h>
38 #include <uncontract.h>
39 #include <dims.h>
40
41 using namespace std;
42 using namespace libint2;
43
44 #define ONLY_CLONE_IF_DIFF 1
45
46 //// utils first
47
48 namespace {
push(DirectedGraph::VPtrAssociativeContainer & vertices,const DirectedGraph::ver_ptr & v)49 void push(DirectedGraph::VPtrAssociativeContainer& vertices, const DirectedGraph::ver_ptr& v) {
50 DirectedGraph::key_type vkey = libint2::key(*v);
51 vertices.insert(std::make_pair(vkey,v));
52 }
push(DirectedGraph::VPtrSequenceContainer & vertices,const DirectedGraph::ver_ptr & v)53 void push(DirectedGraph::VPtrSequenceContainer& vertices, const DirectedGraph::ver_ptr& v) {
54 vertices.push_back(v);
55 }
56 }
57
58
DirectedGraph()59 DirectedGraph::DirectedGraph() :
60 stack_(), targets_(), target_accums_(), label_("graph"), func_names_(),
61 registry_(SafePtr<GraphRegistry>(new GraphRegistry)),
62 iregistry_(SafePtr<InternalGraphRegistry>(new InternalGraphRegistry)),
63 first_to_compute_()
64 {
65 stack_.clear();
66 targets_.clear();
67 }
68
~DirectedGraph()69 DirectedGraph::~DirectedGraph()
70 {
71 reset();
72 }
73
74 void
append_target(const SafePtr<DGVertex> & target)75 DirectedGraph::append_target(const SafePtr<DGVertex>& target)
76 {
77 target->make_a_target();
78 append_vertex(target);
79 push(targets_,target);
80 }
81
82 SafePtr<DGVertex>
append_vertex(const SafePtr<DGVertex> & vertex)83 DirectedGraph::append_vertex(const SafePtr<DGVertex>& vertex)
84 {
85 // If this vertex is owned by this graph, return it immediately
86 if (vertex->dg() == this) {
87 #if DEBUG
88 std::cout << "append_vertex: vertex " << vertex->label() << " is already on" << endl;
89 #endif
90 return vertex;
91 }
92 auto vcopy_on_graph = add_vertex(vertex);
93 // If this is a new vertex -- tell the vertex who its owner is now
94 if (vcopy_on_graph == vertex)
95 vertex->dg(this);
96 return vcopy_on_graph;
97 }
98
99 SafePtr<DGVertex>
add_vertex(const SafePtr<DGVertex> & vertex)100 DirectedGraph::add_vertex(const SafePtr<DGVertex>& vertex)
101 {
102 // if vertex is precomputed -- can replicate it without consequences
103 // else check if it's on already
104 if (!vertex->precomputed()) {
105 SafePtr<DGVertex> vcopy_on_graph = vertex_is_on(vertex);
106 if (vcopy_on_graph)
107 return vcopy_on_graph;
108 }
109 add_new_vertex(vertex);
110 return vertex;
111 }
112
113 void
add_new_vertex(const SafePtr<DGVertex> & vertex)114 DirectedGraph::add_new_vertex(const SafePtr<DGVertex>& vertex)
115 {
116 #if 0
117 // Resize if using std::vector
118 #if !USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
119 if (num_vertices() == stack_.capacity()) {
120 stack_.resize( stack_.capacity() + default_size_ );
121 #if DEBUG
122 cout << "Increased size of DirectedGraph's stack to "
123 << stack_.size() << endl;
124 #endif
125 }
126 #endif
127 #endif
128
129 char label[80]; sprintf(label,"vertex%d",num_vertices());
130 vertex->set_graph_label(label);
131 vertex->dg(this);
132
133 push(stack_,vertex);
134 #if DEBUG
135 cout << "add_new_vertex: added vertex " << vertex->description() << endl;
136 #endif
137
138 return;
139 }
140
141 const SafePtr<DGVertex>&
vertex_is_on(const SafePtr<DGVertex> & vertex) const142 DirectedGraph::vertex_is_on(const SafePtr<DGVertex>& vertex) const
143 {
144 if (vertex->dg() == this)
145 return vertex;
146
147 static SafePtr<DGVertex> null_ptr;
148 #if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
149 typedef vertices::const_iterator citer;
150 typedef vertices::value_type value_type;
151 key_type vkey = key(*vertex);
152 // find the first elemnt with this key and iterate until vertex is found or key changes
153 citer vpos = stack_.find(vkey);
154 const citer end = stack_.end();
155 if (vpos != end) {
156 bool can_find = true;
157 while(can_find) {
158 if (can_find && (vpos->second)->equiv(vertex)) {
159 #if DEBUG
160 std::cout << "vertex_is_on: " << (vpos->second)->label() << std::endl;
161 #endif
162 return vpos->second;
163 }
164 ++vpos;
165 can_find = (vpos->first == vkey) && (vpos != end);
166 }
167 }
168 #else
169 typedef vertices::const_reverse_iterator criter;
170 typedef vertices::reverse_iterator riter;
171 const criter rend = stack_.rend();
172 for(criter v=stack_.rbegin(); v!=rend; ++v) {
173 const ver_ptr& vptr = vertex_ptr(*v);
174 if(vertex->equiv(vptr)) {
175 #if DEBUG
176 std::cout << "vertex_is_on: " << (vptr)->label() << std::endl;
177 #endif
178 return vptr;
179 }
180 }
181 #endif
182 #if DEBUG
183 std::cout << "vertex_is_on: NOT " << (vertex)->label() << std::endl;
184 #endif
185 return null_ptr;
186 }
187
188 namespace {
189 struct __reset_dgvertex {
operator ()__anonad020a590211::__reset_dgvertex190 void operator()(SafePtr<DGVertex>& v) {
191 v->reset();
192 #if DEBUG
193 std::cout << "DirectedGraph::reset: will unregister " << v->label() << std::endl;
194 #endif
195 // remove this vertex from its SingletonManager
196 v->unregister();
197 }
198 };
199 struct __reset_safeptr {
operator ()__anonad020a590211::__reset_safeptr200 void operator()(SafePtr<DGVertex>& v) {
201 v.reset();
202 }
203 };
204 }
205
206 void
del_vertex(vertices::iterator & v)207 DirectedGraph::del_vertex(vertices::iterator& v)
208 {
209 static __reset_dgvertex rv;
210 if (v == stack_.end())
211 throw CannotPerformOperation("DirectedGraph::del_vertex() cannot delete vertex");
212 ver_ptr& vptr = vertex_ptr(*v);
213 // Cannot delete targets. Should I be able to? Probably not
214 if (vptr->is_a_target())
215 throw CannotPerformOperation("DirectedGraph::del_vertex() cannot delete targets");
216 if (vptr->num_exit_arcs() == 0 && vptr->num_entry_arcs() == 0) {
217 #if DEBUG
218 std::cout << "del_vertex: trying to remove " << (vertex_ptr(*v))->label() << std::endl;
219 #endif
220 SafePtr<DGVertex> vptr = vertex_ptr(*v); // keep an instance of the pointer to avoid accidental automatic destruction of the DGVertex object
221 stack_.erase(v);
222 rv(vptr);
223 #if DEBUG
224 std::cout << "del_vertex: successful vertex removal " << std::endl;
225 #endif
226 }
227 else
228 throw CannotPerformOperation("DirectedGraph::del_vertex() cannot delete vertex");
229 }
230
231 namespace{
232 struct __prepare_to_traverse {
operator ()__anonad020a590311::__prepare_to_traverse233 void operator()(SafePtr<DGVertex>& v) {
234 v->prepare_to_traverse();
235 }
236 };
237
sort_children_by_nparents(DGVertex::ArcSetType::const_iterator begin,DGVertex::ArcSetType::const_iterator end)238 std::vector<DGVertex::ArcSetType::value_type> sort_children_by_nparents(DGVertex::ArcSetType::const_iterator begin,
239 DGVertex::ArcSetType::const_iterator end) {
240 // std::sort works only for containers that support random access
241 // std::vector<DGVertex::ArcSetType::value_type> sorted_children;
242 // for(DGVertex::ArcSetType::const_iterator i=begin; i!=end; ++i)
243 // sorted_children.push_back(*i);
244 std::vector<DGVertex::ArcSetType::value_type> sorted_children(begin, end);
245 std::sort(sorted_children.begin(), sorted_children.end(),
246 [](DGVertex::ArcSetType::value_type a,
247 DGVertex::ArcSetType::value_type b) {
248 return a->dest()->num_entry_arcs() < b->dest()->num_entry_arcs();
249 }
250 );
251 return sorted_children;
252 }
253 }
254
255 void
prepare_to_traverse()256 DirectedGraph::prepare_to_traverse()
257 {
258 __prepare_to_traverse __ptt;
259 foreach(__ptt);
260 }
261
262 /**
263 * Recursively traverse depth-first, once a node has been tagged by all of its parents schedule its computation
264 */
265 void
traverse()266 DirectedGraph::traverse()
267 {
268 // Initialization
269 prepare_to_traverse();
270
271 // Start at the targets which don't have parents
272 typedef vertices::const_iterator citer;
273 typedef vertices::iterator iter;
274 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
275 const ver_ptr& vptr = vertex_ptr(*v);
276 if ((vptr)->is_a_target() && (vptr)->num_entry_arcs() == 0) {
277 // First, since this target doesn't have parents we can schedule its computation
278 schedule_computation(vptr);
279
280 //
281 // traverse the rest of graph starting from each child
282 // traversal will start with children with the fewest parents
283 // an explanation for this heuristic: I want to compute the children with 1 parent after
284 // other children so that I can potentially fold their computation with
285 // addition/subtraction to produce FMA and other composite instructions
286 //
287 {
288 // std::sort works only fro containers that support random access
289 std::vector<DGVertex::ArcSetType::value_type> sorted_children = sort_children_by_nparents(vptr->first_exit_arc(),
290 vptr->plast_exit_arc());
291 for(auto a=sorted_children.begin(); a!=sorted_children.end(); ++a) {
292 traverse_from(*a);
293 }
294 }
295
296 }
297 }
298 }
299
300 void
traverse_from(const SafePtr<DGArc> & arc)301 DirectedGraph::traverse_from(const SafePtr<DGArc>& arc)
302 {
303 SafePtr<DGVertex> orig = arc->orig();
304 SafePtr<DGVertex> dest = arc->dest();
305 #if DEBUG_TRAVERSAL
306 std::cout << "traverse_from: orig = " << orig << " dest = " << dest << endl;
307 #endif
308 // no need to compute if precomputed OR has no children
309 if (dest->precomputed() || dest->num_exit_arcs() == 0) {
310 #if DEBUG_TRAVERSAL
311 std::cout << "traverse from: dest is precomputed" << std::endl;
312 #endif
313 return;
314 }
315 // if has been hit by all parents ...
316 const unsigned int num_tags = dest->tag();
317 #if DEBUG_TRAVERSAL
318 std::cout << "traverse from: tagged dest, ntags = " << num_tags << std::endl;
319 #endif
320 const unsigned int num_parents = dest->num_entry_arcs();
321 if (num_tags == num_parents) {
322
323 // ... check if it is on a subtree ...
324 if (SafePtr<DRTree> stree = dest->subtree()) {
325 // ... if yes, schedule only if it is a root of a subtree
326 if (stree->root() == dest)
327 schedule_computation(dest);
328 }
329 // else schedule
330 else
331 schedule_computation(dest);
332
333 {
334 // std::sort works only fro containers that support random access
335 std::vector<DGVertex::ArcSetType::value_type> sorted_children = sort_children_by_nparents(dest->first_exit_arc(),
336 dest->plast_exit_arc());
337 for(auto a=sorted_children.begin(); a!=sorted_children.end(); ++a) {
338 traverse_from(*a);
339 }
340 }
341
342 }
343 }
344
345 void
schedule_computation(const SafePtr<DGVertex> & vertex)346 DirectedGraph::schedule_computation(const SafePtr<DGVertex>& vertex)
347 {
348 vertex->set_postcalc(first_to_compute_);
349 first_to_compute_ = vertex;
350 #if DEBUG || DEBUG_TRAVERSAL
351 std::cout << "schedule_computation: " << vertex << endl;
352 vertex->print(std::cout);
353 #endif
354 }
355
356
357 void
debug_print_traversal(std::ostream & os) const358 DirectedGraph::debug_print_traversal(std::ostream& os) const
359 {
360 SafePtr<DGVertex> current_vertex = first_to_compute_;
361
362 os << "Debug print of traversal order" << endl;
363
364 do {
365 current_vertex->print(os);
366 current_vertex = current_vertex->postcalc();
367 } while (current_vertex != 0);
368 }
369
370 namespace {
371
372 struct __print_vertices_to_dot {
373 bool symbols;
374 std::ostream& os;
__print_vertices_to_dot__anonad020a590511::__print_vertices_to_dot375 __print_vertices_to_dot(bool s, std::ostream& o) : symbols(s), os(o) {}
operator ()__anonad020a590511::__print_vertices_to_dot376 void operator()(const SafePtr<DGVertex>& v) {
377 os << " " << v->graph_label()
378 << " [ label = \"";
379 if (symbols && v->symbol_set())
380 os << v->symbol();
381 else
382 os << v->label();
383 os << "\"]" << endl;
384 }
385 };
386
387 struct __print_arcs_to_dot {
388 std::ostream& os;
__print_arcs_to_dot__anonad020a590511::__print_arcs_to_dot389 __print_arcs_to_dot(std::ostream& o) : os(o) {}
operator ()__anonad020a590511::__print_arcs_to_dot390 void operator()(const SafePtr<DGVertex>& v) {
391 typedef DGVertex::ArcSetType::const_iterator aciter;
392 const aciter abegin = v->first_exit_arc();
393 const aciter aend = v->plast_exit_arc();
394 for(aciter a=abegin; a!=aend; ++a) {
395 SafePtr<DGVertex> dest = (*a)->dest();
396 os << " " << v->graph_label() << " -> "
397 << dest->graph_label() << endl;
398 }
399 }
400 };
401
402 }
403
404 void
print_to_dot(bool symbols,std::ostream & os) const405 DirectedGraph::print_to_dot(bool symbols, std::ostream& os) const
406 {
407 os << "digraph G {" << endl
408 << " size = \"8,8\"" << endl;
409
410 __print_vertices_to_dot pvtd(symbols,os);
411 foreach(pvtd);
412
413 __print_arcs_to_dot patd(os);
414 foreach(patd);
415
416 // Print traversal order using dotted lines
417 SafePtr<DGVertex> current_vertex = first_to_compute_;
418 if (current_vertex != 0) {
419 do {
420 SafePtr<DGVertex> next = current_vertex->postcalc();
421 if (current_vertex && next) {
422 os << " " << current_vertex->graph_label() << " -> "
423 << next->graph_label() << " [ style = dotted constraint = false ]";
424 }
425 current_vertex = next;
426 } while (current_vertex != 0);
427 }
428
429 os << endl << "}" << endl;
430 }
431
432 void
reset()433 DirectedGraph::reset()
434 {
435 // Reset each vertex, releasing all arcs
436 __reset_dgvertex rv;
437 foreach(rv);
438 __reset_safeptr rptr;
439 foreach(rptr);
440
441 // if everything went OK then empty out stack_ and targets_
442 stack_.clear();
443 targets_.clear();
444 first_to_compute_.reset();
445 func_names_.clear();
446 }
447
448
449 /// Apply a strategy to all vertices not yet computed (i.e. which do not have exit arcs)
450 void
apply(const SafePtr<Strategy> & strategy,const SafePtr<Tactic> & tactic)451 DirectedGraph::apply(const SafePtr<Strategy>& strategy,
452 const SafePtr<Tactic>& tactic)
453 {
454 const SafePtr<DirectedGraph> this_ptr = SafePtr_from_this();
455 typedef vertices::const_iterator citer;
456 typedef vertices::iterator iter;
457 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
458 const ver_ptr& vptr = vertex_ptr(*v);
459 if ((vptr)->num_exit_arcs() != 0 || (vptr)->precomputed() || !(vptr)->need_to_compute())
460 continue;
461
462 SafePtr<RecurrenceRelation> rr0 = strategy->optimal_rr(this_ptr,(vptr),tactic);
463 if (rr0 == 0)
464 continue;
465
466 // add children to the graph
467 SafePtr<DGVertex> target = rr0->rr_target();
468 const int num_children = rr0->num_children();
469 for(int c=0; c<num_children; c++) {
470 SafePtr<DGVertex> child = rr0->rr_child(c);
471 bool new_vertex = true;
472 SafePtr<DGVertex> dgchild = append_vertex(child);
473 if (dgchild != child) {
474 child = dgchild;
475 new_vertex = false;
476 }
477 SafePtr<DGArc> arc(new DGArcRel<RecurrenceRelation>(target,child,rr0));
478 target->add_exit_arc(arc);
479 if (new_vertex)
480 apply_to(child,strategy,tactic);
481 }
482 }
483 }
484
485 /// Add vertex to graph and apply a strategy to vertex recursively
486 void
apply_to(const SafePtr<DGVertex> & vertex,const SafePtr<Strategy> & strategy,const SafePtr<Tactic> & tactic)487 DirectedGraph::apply_to(const SafePtr<DGVertex>& vertex,
488 const SafePtr<Strategy>& strategy,
489 const SafePtr<Tactic>& tactic)
490 {
491 bool not_yet_computed = !vertex->precomputed() && vertex->need_to_compute() && (vertex->num_exit_arcs() == 0);
492 if (!not_yet_computed)
493 return;
494 SafePtr<RecurrenceRelation> rr0 = strategy->optimal_rr(SafePtr_from_this(),vertex,tactic);
495 if (rr0 == 0)
496 return;
497
498 SafePtr<DGVertex> target = rr0->rr_target();
499 const int num_children = rr0->num_children();
500 for(int c=0; c<num_children; c++) {
501 SafePtr<DGVertex> child = rr0->rr_child(c);
502 bool new_vertex = true;
503 SafePtr<DGVertex> dgchild = append_vertex(child);
504 if (dgchild != child) {
505 child = dgchild;
506 new_vertex = false;
507 }
508 SafePtr<DGArc> arc(new DGArcRel<RecurrenceRelation>(target,child,rr0));
509 try {
510 target->add_exit_arc(arc);
511 }
512 catch (...) {
513 std::cout << "failed to use RR " << rr0->label() << std::endl;
514 throw;
515 }
516 if (new_vertex)
517 apply_to(child,strategy,tactic);
518 }
519 }
520
521 // Optimize out simple recurrence relations
522 void
optimize_rr_out(const SafePtr<CodeContext> & context)523 DirectedGraph::optimize_rr_out(const SafePtr<CodeContext>& context)
524 {
525 replace_rr_with_expr();
526 remove_trivial_arithmetics();
527 handle_trivial_nodes(context);
528 remove_disconnected_vertices();
529 find_subtrees();
530 }
531
532 // Replace recurrence relations with expressions
533 void
replace_rr_with_expr()534 DirectedGraph::replace_rr_with_expr()
535 {
536 typedef vertices::const_iterator citer;
537 typedef vertices::iterator iter;
538 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
539 const ver_ptr& vptr = vertex_ptr(*v);
540 if ((vptr)->num_exit_arcs()) {
541 SafePtr<DGArc> arc0 = *((vptr)->first_exit_arc());
542 SafePtr<DGArcRR> arc0_cast = dynamic_pointer_cast<DGArcRR,DGArc>(arc0);
543 if (arc0_cast == 0)
544 continue;
545 SafePtr<RecurrenceRelation> rr = arc0_cast->rr();
546
547 // Optimize if the recurrence relation is simple and the target and
548 // children are of the same type
549 if (rr->is_simple() && rr->invariant_type()) {
550
551 #if DEBUG || DEBUG_RESTRUCTURE
552 std::cout << "replace_rr_with_expr: replacing " << rr->label() << endl;
553 std::cout << "replace_rr_with_expr: with " << rr->rr_expr()->description() << endl;
554 std::cout << "replace_rr_with_expr: nchildren = " << rr->num_children() << endl;
555 for(unsigned int c=0; c<rr->num_children(); ++c) {
556 std::cout << "replace_rr_with_expr: child " << c << " " << rr->rr_child(c)->description() << endl;
557 }
558 #endif
559
560 // Remove arcs connecting this vertex to children
561 (vptr)->del_exit_arcs();
562
563 // and instead insert the numerical expression
564 SafePtr<RecurrenceRelation::ExprType> rr_expr = rr->rr_expr();
565 SafePtr<DGVertex> expr_vertex = static_pointer_cast<RecurrenceRelation::ExprType,DGVertex>(rr_expr);
566 expr_vertex = insert_expr_at((vptr),rr_expr);
567 SafePtr<DGArc> arc(new DGArcDirect((vptr),expr_vertex));
568 (vptr)->add_exit_arc(arc);
569
570 }
571 }
572 }
573 }
574
575
576 //
577 // This function is very tricky at the moment. The operands have to be added before the operator
578 // such that operands are guaranteed to be on graph before the operator. This way ExprType::equiv
579 // can simply compare operand pointers and the operator type (very cheap operations compared
580 // to the fully recursive explicit comparison).
581 //
582 SafePtr<DGVertex>
insert_expr_at(const SafePtr<DGVertex> & where,const SafePtr<RecurrenceRelation::ExprType> & expr)583 DirectedGraph::insert_expr_at(const SafePtr<DGVertex>& where, const SafePtr<RecurrenceRelation::ExprType>& expr)
584 {
585 #if DEBUG
586 cout << "insert_expr_at: " << expr->description() << endl;
587 #endif
588
589 typedef RecurrenceRelation::ExprType ExprType;
590 SafePtr<DGVertex> expr_vertex = static_pointer_cast<DGVertex,ExprType>(expr);
591
592 // If the expression is already on then return it
593 if (expr->dg() == this)
594 return expr_vertex;
595
596 SafePtr<DGVertex> left_oper = expr->left();
597 SafePtr<DGVertex> right_oper = expr->right();
598 bool new_left = (left_oper->dg() != this); //
599 bool new_right = (right_oper->dg() != this); //
600 bool need_to_clone = false; // clone expression if (parts of) the expression was found on the graph
601
602 // See if left operand is also an operator
603 const SafePtr<ExprType> left_cast = dynamic_pointer_cast<ExprType,DGVertex>(left_oper);
604 // if yes -- add it to the graph recursively
605 if (left_cast) {
606 left_oper = insert_expr_at(expr_vertex,left_cast);
607 }
608 // else add it directly
609 else {
610 left_oper = append_vertex(left_oper);
611 }
612 #if ONLY_CLONE_IF_DIFF
613 if (left_oper != expr->left()) {
614 #if DEBUG
615 std::cout << "insert_expr_at: append(left) != left" << std::endl;
616 #endif
617 need_to_clone = true;
618 new_left = false;
619 }
620 #else
621 #error "ONLY_CLONE_IF_DIFF must be true"
622 #endif
623
624 // See if right operand is also an operator
625 SafePtr<ExprType> right_cast = dynamic_pointer_cast<ExprType,DGVertex>(right_oper);
626 // if yes -- add it to the graph recursively
627 if (right_cast) {
628 right_oper = insert_expr_at(expr_vertex,right_cast);
629 }
630 // else add it directly
631 else {
632 right_oper = append_vertex(right_oper);
633 }
634 #if ONLY_CLONE_IF_DIFF
635 if (right_oper != expr->right()) {
636 #if DEBUG
637 std::cout << "insert_expr_at: append(right) != right" << std::endl;
638 #endif
639 need_to_clone = true;
640 new_right = false;
641 }
642 #else
643 #error "ONLY_CLONE_IF_DIFF must be true"
644 #endif
645
646 if (need_to_clone) {
647 SafePtr<ExprType> expr_new(new ExprType(expr,left_oper,right_oper));
648 expr_vertex = static_pointer_cast<DGVertex,ExprType>(expr_new);
649 #if DEBUG
650 int nchildren = expr->num_exit_arcs();
651 cout << "insert_expr_at: cloned AlgebraicOperator with " << expr->num_exit_arcs() << " children" << endl;
652 if (nchildren) {
653 cout << "Left: " << expr->left()->description() << endl;
654 cout << "Right: " << expr->right()->description() << endl;
655 }
656 #endif
657 }
658
659 SafePtr<DGVertex> dgexpr_vertex = expr_vertex;
660 if (new_left || new_right)
661 add_new_vertex(expr_vertex);
662 else {
663 const bool do_cse = registry()->do_cse();
664 if (do_cse) {
665 #if DEBUG
666 std::cout << "insert_expr_at: appending vertex " << expr_vertex->description() << std::endl;
667 #endif
668 dgexpr_vertex = append_vertex(expr_vertex);
669 }
670 else {
671 add_new_vertex(expr_vertex);
672 }
673 if (expr_vertex != dgexpr_vertex) {
674 if (new_left || new_right) {
675 cout << "Problem detected: AlgebraicOperator is found on the stack but one of its operands was new" << endl;
676 cout << expr_vertex->description() << endl;
677 cout << dgexpr_vertex->description() << endl;
678 throw std::runtime_error("DirectedGraph::insert_expr_at() -- vertex is not new but one of the operands is");
679 }
680 }
681 expr_vertex = dgexpr_vertex;
682 }
683 SafePtr<DGArc> left_arc(new DGArcDirect(expr_vertex,left_oper));
684 expr_vertex->add_exit_arc(left_arc);
685 SafePtr<DGArc> right_arc(new DGArcDirect(expr_vertex,right_oper));
686 expr_vertex->add_exit_arc(right_arc);
687 #if DEBUG
688 cout << "insert_expr_at: added arc between " << where->description() << " and " << expr_vertex->description() << endl;
689 #endif
690
691 return expr_vertex;
692 }
693
694 // Replace recurrence relations with expressions
695 void
remove_trivial_arithmetics()696 DirectedGraph::remove_trivial_arithmetics()
697 {
698 using libint2::prefactor::Scalar;
699 const SafePtr< CTimeEntity<double> > const_one_point_zero = Scalar(1.0);
700 const SafePtr< CTimeEntity<double> > const_zero_point_zero = Scalar(0.0);
701 typedef vertices::const_iterator citer;
702 typedef vertices::iterator iter;
703 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
704 const ver_ptr& vptr = vertex_ptr(*v);
705 SafePtr< AlgebraicOperator<DGVertex> > oper_cast = dynamic_pointer_cast<AlgebraicOperator<DGVertex>,DGVertex>((vptr));
706 if (oper_cast) {
707
708 //std::cout << "cast " << vptr->description() << " to " << oper_cast->description() << std::endl;
709 typedef DGVertex::ArcSetType::const_iterator aciter;
710 aciter a = oper_cast->first_exit_arc();
711 SafePtr<DGVertex> left = (*a)->dest(); ++a;
712 SafePtr<DGVertex> right = oper_cast->num_exit_arcs()>1 ? (*a)->dest() : left; // num_exit_arcs==1 is a corner case, e.g. 1*1
713
714 using libint2::algebra::OperatorTypes;
715
716 // 1.0 * x = x || 0.0 + x = x
717 if (left->num_entry_arcs() == 1 &&
718 ((oper_cast->type() == OperatorTypes::Times && left->equiv(const_one_point_zero)) ||
719 (oper_cast->type() == OperatorTypes::Plus && left->equiv(const_zero_point_zero)) )
720 ) {
721 #if DEBUG
722 const bool success = remove_vertex_at((vptr),right);
723 if (success)
724 cout << "Removed vertex " << (vptr)->description() << endl;
725 #else
726 remove_vertex_at((vptr),right);
727 #endif
728 }
729 // x * 1.0 = x || x + 0.0 = x
730 else
731 if (right->num_entry_arcs() == 1 &&
732 ((oper_cast->type() == OperatorTypes::Times && right->equiv(const_one_point_zero)) ||
733 (oper_cast->type() == OperatorTypes::Plus && right->equiv(const_zero_point_zero)) )
734 ) {
735 #if DEBUG
736 const bool success = remove_vertex_at((vptr),left);
737 if (success)
738 cout << "Removed vertex " << (vptr)->description() << endl;
739 #else
740 remove_vertex_at((vptr),left);
741 #endif
742 }
743
744 // NOTE : more cases to come
745 }
746 }
747 }
748
749 namespace {
stack_symbol(const SafePtr<CodeContext> & ctext,const DGVertex::Address & address,const DGVertex::Size & size,const std::string & low_rank,const std::string & veclen,const std::string & prefix)750 std::string stack_symbol(const SafePtr<CodeContext>& ctext, const DGVertex::Address& address, const DGVertex::Size& size,
751 const std::string& low_rank, const std::string& veclen,
752 const std::string& prefix)
753 {
754 ostringstream oss;
755 std::string stack_address = ctext->stack_address(address);
756 oss << prefix << "[((hsi*" << size << "+"
757 << stack_address << ")*" << low_rank << "+lsi)*"
758 << veclen << "]";
759 return oss.str();
760 }
761
762 /// Returns a "vector" form of stack symbol, e.g. converts libint->stack[x] to libint->stack[x+vi]
to_vector_symbol(const SafePtr<DGVertex> & v)763 inline std::string to_vector_symbol(const SafePtr<DGVertex>& v)
764 {
765 std::string::size_type current_pos = 0;
766 std::string symb = v->symbol();
767 // replace repeatedly until the string is exhausted
768 while(current_pos != std::string::npos) {
769
770 // find "[" first
771 const std::string left_braket("[");
772 std::string::size_type where = symb.find(left_braket,current_pos);
773 current_pos = where;
774 // if the prefix indicating a stack symbol found:
775 // 1) make sure vi doesn't appear between the brakets
776 // 2) replace "]" with "+vi]"
777 if (where != std::string::npos) {
778 const std::string right_braket("]");
779 std::string::size_type where = symb.find(right_braket,current_pos);
780 if (where == std::string::npos)
781 throw logic_error("to_vector_symbol() -- address is set but no right braket found");
782
783 const std::string forbidden("vi");
784 std::string::size_type pos = symb.find(forbidden,current_pos);
785 if (pos == std::string::npos || pos > where) {
786 const std::string what_to_add("+vi");
787 symb.insert(where,what_to_add);
788 current_pos = where + 4;
789 }
790 else {
791 current_pos = where + 1;
792 }
793 }
794 } // end of while
795 return symb;
796 }
797 };
798
799 //
800 // Handles "trivial" nodes. A node is trivial is it satisfies the following conditions:
801 // 0) not a target
802 // 1) has only one child
803 // 2) the exit arc is of a trivial type (DGArcDirect or IntegralSet_to_Integral applied to node of size 1)
804 //
805 // By "handling" I mean either removing the node from the graph or making a node refer to another node so that
806 // no code is generated for it.
807 //
808 void
handle_trivial_nodes(const SafePtr<CodeContext> & context)809 DirectedGraph::handle_trivial_nodes(const SafePtr<CodeContext>& context)
810 {
811 typedef vertices::const_iterator citer;
812 typedef vertices::iterator iter;
813 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
814 const ver_ptr& vptr = vertex_ptr(*v);
815 // or if has more than 1 child
816 if ((vptr)->num_exit_arcs() != 1)
817 continue;
818 SafePtr<DGArc> arc = *((vptr)->first_exit_arc());
819
820 // Is the exit arc DGArcRel<IntegralSet_to_Integrals> and (vptr)->size() == 1?
821 if ((vptr)->size() == 1) {
822 SafePtr<DGArcRR> arc_cast = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
823 if (arc_cast) {
824 SafePtr<RecurrenceRelation> rr = arc_cast->rr();
825 SafePtr<IntegralSet_to_Integrals_base> rr_cast = dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(rr);
826 if (rr_cast) {
827 SafePtr<DGVertex> child = arc->dest();
828
829 //if (child->symbol_set() == false)
830 {
831 const std::string stack_name("stack");
832 const SafePtr<ImplicitDimensions>& dims = ImplicitDimensions::default_dims();
833 std::string low_rank = dims->low_label();
834 std::string veclen = dims->vecdim_label();
835
836 if ((vptr)->address_set()) {
837 child->set_symbol(stack_symbol(context,(vptr)->address()+0,(vptr)->size(),low_rank,veclen,stack_name));
838 }
839 if ((vptr)->symbol_set()) {
840 child->set_symbol(stack_symbol(context,0,(vptr)->size(),low_rank,veclen,(vptr)->symbol()));
841 }
842 }
843
844 vptr->refer_this_to(child);
845 vptr->reset_symbol();
846
847 }
848 }
849 }
850
851
852 // Is the exit arc DGArcDirect?
853 {
854 SafePtr<DGArcDirect> arc_cast = dynamic_pointer_cast<DGArcDirect,DGArc>(arc);
855 if (arc_cast) {
856 // remove the vertex, if possible
857 // if this is a target -- cannot remove
858 if (!(vptr)->is_a_target())
859 remove_vertex_at((vptr),arc->dest());
860 }
861 }
862
863 // NOTE : more cases to come
864 }
865 }
866
867
868 // If v1 and v2 are connected by DGArcDirect and all entry arcs to v1 are of the DGArcDirect type as well,
869 // this function will reattach all arcs extering v1 to v2 and remove v1 from the graph alltogether.
870 // return true if successful, false otherwise
871 bool
remove_vertex_at(const SafePtr<DGVertex> & v1,const SafePtr<DGVertex> & v2)872 DirectedGraph::remove_vertex_at(const SafePtr<DGVertex>& v1, const SafePtr<DGVertex>& v2)
873 {
874 #if DEBUG
875 cout << "remove_vertex_at: replacing " << v1->description() << " with " << v2->description() << endl;
876 #endif
877
878 // Collect all entry arcs in a container
879 DGVertex::ArcSetType v1_entry;
880 typedef DGVertex::ArcSetType::iterator aiter;
881 typedef DGVertex::ArcSetType::const_iterator aciter;
882 const aciter abegin = v1->first_entry_arc();
883 const aciter aend = v1->plast_entry_arc();
884 // Verify that all entry arcs are DGArcDirect
885 for(aciter a=abegin; a!=aend; ++a) {
886 // See if this is a direct arc -- otherwise cannot do this
887 SafePtr<DGArc> arc = (*a);
888 SafePtr<DGArcDirect> arc_cast = dynamic_pointer_cast<DGArcDirect,DGArc>(arc);
889 if (arc_cast == 0)
890 return false;
891 v1_entry.push_back(*a);
892 #if DEBUG
893 std::cout << "remove_vertex_at: examined v1 entry arc: from " << (*a)->orig()->description() << " to " << (*a)->dest()->description() << std::endl;
894 #endif
895 }
896
897 // Verify that v1 and v2 are connected by an arc and it is the only arc exiting v1
898 /*if (v1->num_exit_arcs() != 1 || v1->exit_arc(0)->dest() != v2)
899 return false;*/
900
901 // See if this is a direct arc -- otherwise cannot do this
902 SafePtr<DGArc> arc = *(v1->first_exit_arc());
903 SafePtr<DGArcDirect> arc_cast = dynamic_pointer_cast<DGArcDirect,DGArc>(arc);
904 if (arc_cast == 0)
905 return false;
906
907 //
908 // OK, now do work!
909 //
910
911 #if DEBUG || DEBUG_RESTRUCTURE
912 std::cout << "remove_vertex_at: v1" << endl; v1->print(std::cout);
913 std::cout << "remove_vertex_at: v2" << endl; v2->print(std::cout);
914 #endif
915
916 // Reconnect each of v1's entry arcs to v2
917 unsigned int c = 0;
918 for(aiter i=v1_entry.begin(); i != v1_entry.end(); ++i, ++c) {
919 SafePtr<DGVertex> parent = (*i)->orig();
920 #if DEBUG || DEBUG_RESTRUCTURE
921 cout << "remove_vertex_at: replacing arc " << c << " connecting " << parent->description() << " to " << (*i)->dest()->description() << endl;
922 cout << "remove_vertex_at: replacing arc " << c << " connecting " << parent << " to " << (*i)->dest() << endl;
923 #endif
924 SafePtr<DGArcDirect> new_arc(new DGArcDirect(parent,v2));
925 #if DEBUG || DEBUG_RESTRUCTURE
926 cout << "remove_vertex_at: with arc " << " connecting " << parent->description() << " to " << v2->description() << endl;
927 cout << "remove_vertex_at: with arc " << " connecting " << parent << " to " << v2 << endl;
928 #endif
929 parent->replace_exit_arc(*i,new_arc);
930 #if DEBUG || DEBUG_RESTRUCTURE
931 cout << "Replaced arcs: parent " << parent->description() << " now connected to " << new_arc->dest()->description() << endl;
932 cout << " ptr = " << parent << endl;
933 const unsigned int nchildren = parent->num_exit_arcs();
934 cout << " parent has " << nchildren << " children" << endl;
935 unsigned int c=0;
936 for(aciter a = parent->first_exit_arc(); a!=parent->plast_exit_arc(); ++a, ++c) {
937 cout << " child " << c << " " << (*a)->dest()->description() << endl;
938 cout << " child " << c << " ptr = " << (*a)->dest() << endl;
939 }
940 #endif
941 }
942
943 #if DEBUG || DEBUG_RESTRUCTURE
944 std::cout << "remove_vertex_at: v1" << endl; v1->print(std::cout);
945 std::cout << "remove_vertex_at: v2" << endl; v2->print(std::cout);
946 #endif
947
948 // and fully disconnect this vertex
949 v1->detach();
950 #if DEBUG || DEBUG_RESTRUCTURE
951 std::cout << "remove_vertex_at: detached " << v1->description() << endl;
952 #endif
953
954 return true;
955 }
956
957 void
remove_disconnected_vertices()958 DirectedGraph::remove_disconnected_vertices()
959 {
960 typedef vertices::const_iterator citer;
961 typedef vertices::iterator iter;
962 for(iter v=stack_.begin(); v!=stack_.end();) {
963 const ver_ptr& vptr = vertex_ptr(*v);
964 iter vnext = v; ++vnext; // note the next value of iterator before trying to erase
965 if ((vptr)->num_entry_arcs() == 0 && (vptr)->num_exit_arcs() == 0 && !(vptr)->is_a_target()) {
966 #if DEBUG
967 cout << "Trying to erase disconnected vertex " << (vptr)->description() << " num_vertices = " << num_vertices() << endl;
968 #endif
969 try { del_vertex(v); }
970 catch (CannotPerformOperation& v) {
971 #if DEBUG
972 cout << "But couldn't!!!" << endl;
973 #endif
974 throw v;
975 }
976 }
977 // update the iterator
978 v = vnext;
979 }
980 }
981
982 // generate_code uses this helper function. It's declared in libint2 namespace because
983 // it's also used elsewhere
984 namespace libint2 {
985 std::string
declare_function(const SafePtr<CodeContext> & context,const SafePtr<ImplicitDimensions> & dims,const SafePtr<CodeSymbols> & args,const std::string & tlabel,const std::string & function_descr,std::ostream & decl)986 declare_function(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims,
987 const SafePtr<CodeSymbols>& args, const std::string& tlabel, const std::string& function_descr,
988 std::ostream& decl) {
989
990 std::string function_name = label_to_funcname(function_descr);
991 function_name = context->label_to_name(function_name);
992
993 decl << context->code_prefix();
994 std::string func_decl;
995 std::ostringstream oss;
996 oss << context->type_name<void>() << " "
997 << function_name << "(" << context->const_modifier()
998 << context->inteval_type_name(tlabel) << "* inteval";
999 const unsigned int nargs = args->n();
1000 if (nargs > 0) {
1001 // first argument is always the target, which is never a const
1002 oss << ", " << context->type_name<double*>() << " "
1003 << args->symbol(0);
1004 for(unsigned int a=1; a<nargs; a++) {
1005 oss << ", " << context->type_name<const double*>() << " "
1006 << args->symbol(a);
1007 }
1008 }
1009 if (!dims->high_is_static()) {
1010 oss << ", " << context->type_name<int>() << " "
1011 << dims->high()->id();
1012 }
1013 if (!dims->low_is_static()) {
1014 oss << ", " << context->type_name<int>() << " "
1015 <<dims->low()->id();
1016 }
1017 oss << ")";
1018 func_decl = oss.str();
1019
1020 decl << func_decl << context->end_of_stat() << endl;
1021 decl << context->code_postfix();
1022
1023 return func_decl;
1024 }
1025 }
1026
1027 namespace {
1028 template <class Container>
1029 SafePtr<MemBlockSet>
to_memoryblks(Container & vertices)1030 to_memoryblks(Container& vertices) {
1031 SafePtr<MemBlockSet> result(new MemBlockSet);
1032 typedef typename Container::const_iterator citer;
1033 typedef typename Container::iterator iter;
1034 citer end(vertices.end());
1035 for(iter v=vertices.begin(); v!=end; ++v) {
1036 const DirectedGraph::ver_ptr& vptr = vertex_ptr(*v);
1037 result->push_back(MemBlock((vptr)->address(),(vptr)->size(),false,SafePtr<MemBlock>(),SafePtr<MemBlock>()));
1038 }
1039 return result;
1040 }
1041 };
1042
1043 //
1044 //
1045 //
// Generates the declaration and the definition of the function that evaluates this graph.
//   context -- code context that renders all language-specific constructs
//   memman  -- memory manager passed on to allocate_mem()
//   dims    -- implicit dimensions of the generated function
//   args    -- symbols of the function arguments
//   label   -- description of the computation; the function name is derived from it
//   decl    -- stream that receives the copyright, standard header, optional comment and declaration
//   def     -- stream that receives the copyright, headers, includes and the function body
void
DirectedGraph::generate_code(const SafePtr<CodeContext>& context, const SafePtr<MemoryManager>& memman,
                             const SafePtr<ImplicitDimensions>& dims, const SafePtr<CodeSymbols>& args,
                             const std::string& label,
                             std::ostream& decl, std::ostream& def)
{
  // label of the current library task is needed to name the evaluator type
  LibraryTaskManager& taskmgr = LibraryTaskManager::Instance();
  const std::string tlabel = taskmgr.current().label();

  //
  // Generate function's declaration
  //

  decl << context->copyright();
  decl << context->std_header();
  std::string comment("This code computes "); comment += label; comment += "\n";
  if (context->comments_on())
    decl << context->comment(comment) << endl;

  // writes the declaration to decl and returns the signature, which is reused for the definition
  const std::string func_decl = declare_function(context,dims,args,tlabel,label,decl);

  // are there prerequisite vertices that are not precomputed? Then may need to call precompute function
  const bool missing_prereqs = this->missing_prerequisites();
  std::string func_prereq_name;
  std::string func_prereq_decl;
  if (missing_prereqs) {
    // the declaration of the prerequisite function goes to a local stream that is discarded here
    std::ostringstream oss;
    func_prereq_name = context->label_to_name(label_to_funcname(label)) + "_prereq";
    func_prereq_decl = declare_function(context,dims,args,tlabel,func_prereq_name,oss);
  }

  //
  // Generate function's definition
  //

  def << context->copyright();
  // include standard headers
  def << context->std_header();
  // include declarations for all function calls:
  // 1) update func_names_
  // 2) (optional) if will compute prerequisites add the name of the function that will compute them
  // 3) include their headers into the current definition file
  update_func_names();
  if (missing_prereqs)
    func_names_[func_prereq_name] = true;
  for(FuncNameContainer::const_iterator fn=func_names_.begin(); fn!=func_names_.end(); fn++) {
    string function_name = (*fn).first;
    def << "#include <"
        << context->label_to_name(context->cparams()->api_prefix() + function_name)
        << ".h>" << endl;
  }
  def << endl;

  // open the function body
  def << context->code_prefix();
  def << func_decl << context->open_block() << endl;
  def << context->std_function_header();

  // allocate data and assign symbols
  context->reset();
  // if we vectorize by-line then all data is allocated on Libint's stack
  if (context->cparams()->vectorize_by_line())
    allocate_mem(memman,dims,0);
  // otherwise only arrays go on Libint's stack (scalars are handled by the compiler)
  else
    allocate_mem(memman,dims,1);
  assign_symbols(context,dims);

  // then ...
  if (missing_prereqs) { // need to compute prerequisites?

    // if profiling is on, start the next timer, after stopping the current timer (if possible)
    if (registry()->current_timer() >= 0) {
      const int current_timer = registry()->current_timer();
      def << context->macro_if("LIBINT2_CPLUSPLUS_STD >= 2011");
      if (current_timer > 0)
        def << "inteval->timers->stop(" << current_timer << ");" << std::endl;
      def << "inteval->timers->start(" << current_timer+1 << ");" << std::endl;
      def << context->macro_endif();
    }

    //
    // need to zero out space for all missing prerequisites -- prereq evaluator will accumulate into that space
    //

    // get prerequisites
    PrerequisitesExtractor pe;
    this->foreach(pe);
    // merge their blocks (likely into one -- allocate_mem should have taken care of that)
    SafePtr<MemBlockSet> targetblks = to_memoryblks(pe.vertices);
    merge(*targetblks);

    // zero each one out
    for(MemBlockSet::iterator
        b = targetblks->begin();
        b != targetblks->end();
        ++b) {
      const size s = b->size();
      SafePtr<CTimeEntity<int> > bdim(new CTimeEntity<int>(s));

      // number of elements to zero = block size times the vector dimension,
      // which is either a runtime or a compile-time entity
      SafePtr<Entity> bvecdim;
      if (!dims->vecdim_is_static()) {
        SafePtr< RTimeEntity<EntityTypes::Int> > vecdim = dynamic_pointer_cast<RTimeEntity<EntityTypes::Int>,Entity>(dims->vecdim());
        bvecdim = vecdim * bdim;
      }
      else {
        SafePtr< CTimeEntity<int> > vecdim = dynamic_pointer_cast<CTimeEntity<int>,Entity>(dims->vecdim());
        bvecdim = vecdim * bdim;
      }
      def << "_libint2_static_api_bzero_short_(" << registry()->stack_name() << "+"
          << b->address() << "*" << dims->vecdim()->id() << "," << bvecdim->id() << ")" << endl;
    }

    // and call the prereq evaluator once for every value of the contraction index c in [0,contrdepth)
    SafePtr<Entity> zero(new CTimeEntity<int>(0));
    SafePtr<Entity> contr_depth(new RTimeEntity<EntityTypes::Int>("contrdepth"));
    std::string contr_index("c");
    SafePtr<ForLoop> contr_loop(new ForLoop(context,contr_index,contr_depth,zero));

    def << context->decldef("const int","contrdepth","inteval->contrdepth");
    def << contr_loop->open();
    def << func_prereq_name << "(inteval+c, "
        << registry()->stack_name()
        << ")" << context->end_of_stat() << endl;
    def << contr_loop->close() << endl;

    // if profiling is on, stop the timer started for the prerequisites and restart the current timer (if possible)
    if (registry()->current_timer() >= 0) {
      const int current_timer = registry()->current_timer();
      def << context->macro_if("LIBINT2_CPLUSPLUS_STD >= 2011");
      def << "inteval->timers->stop(" << current_timer+1 << ");" << std::endl;
      if (current_timer > 0)
        def << "inteval->timers->start(" << current_timer << ");" << std::endl;
      def << context->macro_endif();
    }

  }

  // if profiling=on and the current timer is outermost, start it
  // if this is not outermost timer, it has been started outside of this function
  if (registry()->current_timer() == 0) {
    def << context->macro_if("LIBINT2_CPLUSPLUS_STD >= 2011");
    def << "inteval->timers->start(0);" << std::endl;
    def << context->macro_endif();
  }

  // now print out the code for this graph
  print_def(context,def,dims,args);

  // if profiling=on and the current timer is outermost, stop it
  // if this is not outermost timer, it is managed outside of this function
  if (registry()->current_timer() == 0) {
    def << context->macro_if("LIBINT2_CPLUSPLUS_STD >= 2011");
    def << "inteval->timers->stop(0);" << std::endl;
    def << context->macro_endif();
  }

  // close the function body
  def << context->close_block() << endl;
  def << context->code_postfix();
}
1201
// Assigns addresses, obtained from memman, to the vertices of this graph.
//   memman            -- memory manager whose alloc()/free() are used
//   dims              -- implicit dimensions (not referenced in the body of this function)
//   min_size_to_alloc -- vertices of this size or smaller are not allocated, except for targets
//                        and for certain unrolled integral sets (see the traversal loop below)
// The order of allocation is: non-precomputed prerequisites, the target accumulation buffer
// (if needed), targets without symbols and addresses, then all other vertices in the order
// of traversal starting with first_to_compute_.
void DirectedGraph::allocate_mem(const SafePtr<MemoryManager>& memman,
                                 const SafePtr<ImplicitDimensions>& dims,
                                 unsigned int min_size_to_alloc)
{
  // NOTE does this belong here?
  // First, reset tag counters
  prepare_to_traverse();

  // Local helper that computes the aggregate size of, and allocates memory for, a set of targets.
  // If all_targets is false only the targets whose symbol and address are both unset are considered.
  struct TargetAllocator {
    typedef DirectedGraph::targets::const_iterator target_citer;
    typedef DirectedGraph::targets::iterator target_iter;
    typedef DirectedGraph::size sz;
    typedef DirectedGraph::address address;

    // both are held by reference: the allocator must not outlive the objects given to the constructor
    const DirectedGraph::targets& targets_;
    const SafePtr<MemoryManager>& memman_;
    bool all_targets_;
    sz size_;

    TargetAllocator(const DirectedGraph::targets& t,
                    const SafePtr<MemoryManager>& mm,
                    bool all_targets) :
      targets_(t),
      memman_(mm),
      all_targets_(all_targets)
    {
      // compute the aggregate size of all targets
      target_citer end = targets_.end();
      size_ = 0;
      for(target_citer t=targets_.begin(); t!=end; ++t) {
        const ver_ptr& tptr = vertex_ptr(*t);
        if (all_targets_ ||
            (!tptr->symbol_set() &&
             !tptr->address_set()
            )
           ) {
          size_ += (tptr)->size();
        }
      }
    }

    // aggregate size of the targets selected by all_targets_
    sz size() const {return size_;}

    // allocates each selected target separately, in the order in which they appear in targets_
    void allocate() {
      for(target_citer v=targets_.begin(); v!=targets_.end(); ++v) {
        const ver_ptr& vptr = vertex_ptr(*v);
        if (all_targets_ ||
            (!vptr->symbol_set() &&
             !vptr->address_set()
            )
           ) {
          const MemoryManager::Address address = memman_->alloc(vptr->size());
          // if the vertex is just an alias, pass the address on
          if (vptr->refers_to_another())
            (*(vptr->first_exit_arc()))->dest()->set_address(address);
          else
            vptr->set_address(address);
        }
      }
    }
  };

  //
  // First, allocate all prerequisites that are not precomputed
  // since they will be computed before the evaluation of this graph
  // they will be targets of the previous computation and will be
  // at the beginning of the previous stack. Preallocate them here.
  //
  if (this->missing_prerequisites()) {
    PrerequisitesExtractor pe;
    this->foreach(pe);
    targets prereqs(pe.vertices.size());
    std::copy(pe.vertices.begin(), pe.vertices.end(), prereqs.begin());
    // prereqs will be put on the prereq DirectedGraph in the same order
    // so use the same allocation mechanism here as for targets
    const bool all_targets = true;
    TargetAllocator ta(prereqs, memman, all_targets);
    ta.allocate();
  }

  //
  // If need to accumulate targets, special events must happen here.
  //
  // NOTES on how to handle accumulation
  // 1) if all targets are unrolled then need to identify which integrals are part of target sets and use += instead of =
  // 2) if no targets are unrolled then allocate extra space for the target quartets and, after code has been generated,
  //    accumulate target sets into those
  //    EXCEPTION no new space for integral sets is needed if they are decontracted
  // 3) if some targets are not unrolled then still need the extra space. The targets which were not unrolled should be handled
  //    as usual, i.e. not accumulated -- accumulation happens at the end
  if (registry()->accumulate_targets()) {

    // need extra buffers for targets if some are not unrolled ( and not decontracted)
    //const bool need_copies_of_targets = nonunrolled_targets(targets_);
    const bool need_copies_of_targets = std::find_if(targets_.begin(),
                                                     targets_.end(),
                                                     [](SafePtr<DGVertex> i){
                                                       return !DecontractedIntegralSet()(i) && !UnrolledIntegralSet()(i);
                                                     }) != targets_.end();

    // record the decision in the internal registry
    iregistry()->accumulate_targets_directly(!need_copies_of_targets);

    if (need_copies_of_targets) {

      // the accumulation buffer is sized to hold every target
      const bool all_targets = true;
      TargetAllocator ta(targets_, memman, all_targets);
      const size size_of_targets = ta.size();
      iregistry()->size_of_target_accum(size_of_targets);

      // allocate every target accumulator manually: one buffer is allocated and
      // the accumulator of each target is placed in it back-to-back, in the order of targets_
      const address targets_buffer = memman->alloc(size_of_targets);
      address curr_ptr = targets_buffer;
      for(target_citer t=targets_.begin(); t!=targets_.end(); ++t) {
        const ver_ptr& tptr = vertex_ptr(*t);
        target_accums_.push_back(curr_ptr);
        curr_ptr += (tptr)->size();
      }

    } // need copies of targets
  } // need to accumulate targets

  // Second, MUST allocate space for all targets whose symbols are not set explicitly
  // If a symbol is set means the object is not on stack (e.g. if location of target
  // is passed as an argument to set-level function)
  // This code ensures that target quartets are persistent, i.e. never overwritten, and can be accumulated into
  {
    const bool all_targets = false; // only targets without symbols
    TargetAllocator ta(targets_, memman, all_targets);
    ta.allocate();
  }

  //
  // How memory management happens:
  // Go through the traversal order and at each step tag every child
  // Once a child receives same number of tags as the number of parents,
  // it can be deallocated
  //
  // NOTE(review): first_to_compute_ is dereferenced unconditionally below, i.e. it is assumed to be non-null
  SafePtr<DGVertex> vertex = first_to_compute_;
  do {
    SafePtr<DGArcRR> arcrr;
    // memory only needs to be managed for some quantities:
    // this conditional decides whether this vertex is on the stack
    bool need_to_allocate = true;

    // If symbol is set then the object is not on stack
    need_to_allocate &= not vertex->symbol_set();

    // if address is already set, no need to manage
    need_to_allocate &= not vertex->address_set();

    // precomputed objects don't go on stack
    need_to_allocate &= not vertex->precomputed();

    // don't allocate on stack if smaller than or equal to min_size_to_alloc
    // two exceptions, however:
    // 1) it's a target
    // 2) it's an unrolled integral set whose members are not precomputed quantities
    //    typically integral sets of size 1 are precomputed and don't need to be on stack,
    //    however if they are not, the integral will need to be stored somewhere and the rule for
    //    assigning code symbols to members of unrolled integral sets requires the integral set
    //    to have an address assigned
    need_to_allocate &= ( vertex->size() > min_size_to_alloc ||
                          vertex->is_a_target() ||
                          (vertex->size() == vertex->num_exit_arcs() &&
                           ( (arcrr = dynamic_pointer_cast<DGArcRR,DGArc>(*(vertex->first_exit_arc()))) != 0 ?
                               dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(arcrr->rr()) != 0 :
                               false ) &&
                           !(*(vertex->first_exit_arc()))->dest()->precomputed()
                          )
                        );

    // if this is a member of unrolled integral set whose address or symbol is set, no need to allocate
    if (need_to_allocate && vertex->num_entry_arcs() != 0) {
      arcrr = dynamic_pointer_cast<DGArcRR,DGArc>(*(vertex->first_entry_arc()));
      if (arcrr) {
        if (dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(arcrr->rr()) != 0) {
          if (arcrr->orig()->symbol_set() || arcrr->orig()->address_set())
            need_to_allocate = false;
        }
      }
    }

    if (need_to_allocate) {
      MemoryManager::Address addr = memman->alloc(vertex->size());
      vertex->set_address(addr);
#if DEBUG
      cout << "allocated " << vertex->description() << " at " << addr << " (size=" << vertex->size() << ")" << endl;
#endif

      typedef DGVertex::ArcSetType::const_iterator aciter;
      const aciter abegin = vertex->first_exit_arc();
      const aciter aend = vertex->plast_exit_arc();
      // Visit every child of the vertex that was just allocated and release the memory of each child
      // whose tag count equals its number of entry arcs (see "How memory management happens" above).
      // NOTE(review): tag() presumably increments the child's tag counter and returns the new value -- confirm in DGVertex
      for(aciter a=abegin; a!=aend; ++a) {
        SafePtr<DGVertex> child = (*a)->dest();
        const unsigned int ntags = child->tag();
        // Do NOT deallocate if it's a target!
        if (ntags == child->num_entry_arcs() && child->address_set() && !child->is_a_target()) {
          memman->free(child->address());
        }
      }
    }
    // proceed to the next vertex in the traversal order
    vertex = vertex->postcalc();
  }while (vertex != 0);
}
1407
1408 void
assign_symbols(const SafePtr<CodeContext> & context,const SafePtr<ImplicitDimensions> & dims)1409 DirectedGraph::assign_symbols(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims)
1410 {
1411 std::ostringstream os;
1412 const std::string null_str("");
1413 const std::string stack_name = registry()->stack_name();
1414 // There used to be a compiler/library/bug on OS X? The above failed... Valgrind under Linux shows no memory problems...
1415 //const std::string stack_name("stack");
1416
1417 // Generate the label for the rank of the low dimension
1418 std::string low_rank = dims->low_label();
1419 std::string veclen = dims->vecdim_label();
1420
1421 // First, set symbols for all vertices which have address assigned
1422 typedef vertices::const_iterator citer;
1423 typedef vertices::iterator iter;
1424 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
1425 const ver_ptr& vptr = vertex_ptr(*v);
1426 if (!(vptr)->symbol_set() && (vptr)->address_set()) {
1427 (vptr)->set_symbol(stack_symbol(context,(vptr)->address(),(vptr)->size(),low_rank,veclen,stack_name));
1428 }
1429 }
1430
1431 // Second, find all nodes which were unrolled using IntegralSet_to_Integrals:
1432 // 1) such nodes do not need symbols generated since they never appear in the code expicitly
1433 // 2) children of such nodes have symbols that depend on the parent's address
1434 // upstream such nodes of size one were aliased ("referred") to their children -- just skip these
1435 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
1436 const ver_ptr& vptr = vertex_ptr(*v);
1437 if ((vptr)->num_exit_arcs() == 0)
1438 continue;
1439 if (vptr->refers_to_another())
1440 continue;
1441 SafePtr<DGArc> arc = *((vptr)->first_exit_arc());
1442 SafePtr<DGArcRR> arc_rr = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
1443 if (arc_rr == 0)
1444 continue;
1445 SafePtr<RecurrenceRelation> rr = arc_rr->rr();
1446 SafePtr<IntegralSet_to_Integrals_base> iset_to_i = dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(rr);
1447 if (iset_to_i == 0) {
1448 continue;
1449 }
1450 else {
1451 typedef DGVertex::ArcSetType::const_iterator aciter;
1452 const aciter abegin = (vptr)->first_exit_arc();
1453 const aciter aend = (vptr)->plast_exit_arc();
1454 unsigned int c = 0;
1455 // Verify that all entry arcs are DGArcDirect
1456 for(aciter a=abegin; a!=aend; ++a, ++c) {
1457 SafePtr<DGVertex> child = (*a)->dest();
1458 // If the child is precomputed and it's parent symbol is not set -- its symbol will be set as usual
1459 if (!child->precomputed() || (vptr)->symbol_set()) {
1460
1461 // check if symbol is already set ... this indicates interference with the logic upstream, it should be set here?
1462 if (child->symbol_set()) {
1463 std::cout << "WARNING: symbol for " << child->description() << " (unrolled integral) already set" << std::endl;
1464 continue;
1465 }
1466
1467 if ((vptr)->address_set()) {
1468 child->set_symbol(stack_symbol(context,(vptr)->address()+c,(vptr)->size(),low_rank,veclen,stack_name));
1469 }
1470 else {
1471 child->set_symbol(stack_symbol(context,c,(vptr)->size(),low_rank,veclen,(vptr)->symbol()));
1472 }
1473 }
1474 }
1475 (vptr)->refer_this_to((*((vptr)->first_exit_arc()))->dest());
1476 (vptr)->reset_symbol();
1477 }
1478 }
1479
1480 // then process all other symbols, EXCEPT operators
1481 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
1482 const ver_ptr& vptr = vertex_ptr(*v);
1483 #if DEBUG
1484 cout << "Trying to assign symbol to " << (vptr)->description() << endl;
1485 #endif
1486 if ((vptr)->symbol_set()) {
1487 #if DEBUG
1488 cout << "symbol already set to " << (vptr)->symbol() << endl;
1489 #endif
1490 continue;
1491 }
1492
1493 // test if the vertex is a static quantity, like a constant
1494 {
1495 typedef CTimeEntity<double> cdouble;
1496 SafePtr<cdouble> ptr_cast = dynamic_pointer_cast<cdouble,DGVertex>((vptr));
1497 if (ptr_cast) {
1498 (vptr)->set_symbol(ptr_cast->label());
1499 continue;
1500 }
1501 }
1502
1503 // test if the vertex is precomputed runtime quantity, like a geometric parameter
1504 if ((vptr)->precomputed()) {
1505 std::string symbol("inteval->");
1506 symbol += context->label_to_name((vptr)->label());
1507 symbol += "[vi]";
1508 (vptr)->set_symbol(symbol);
1509 continue;
1510 }
1511
1512 // test if the vertex is other runtime quantity
1513 {
1514 typedef RTimeEntity<double> cdouble;
1515 SafePtr<cdouble> ptr_cast = dynamic_pointer_cast<cdouble,DGVertex>((vptr));
1516 if (ptr_cast) {
1517 (vptr)->set_symbol(ptr_cast->label());
1518 continue;
1519 }
1520 }
1521
1522 // test if the vertex is some other data that was not allocated on stack
1523 #if 1
1524 if (vptr->size() == 1) {
1525 { // basic integral
1526 typedef IntegralSet<IncableBFSet> intset;
1527 SafePtr<intset> ptr_cast = dynamic_pointer_cast<intset,DGVertex>((vptr));
1528 if (ptr_cast) {
1529 (vptr)->set_symbol(context->unique_name<EntityTypes::FP>());
1530 continue;
1531 }
1532 }
1533 }
1534 #endif
1535
1536 } // done with everything BUT operators
1537
1538 // finally, process all operators (start with most recently added vertices since those are
1539 // much more likely to be on the bottom of the graph).
1540 typedef vertices::const_reverse_iterator criter;
1541 typedef vertices::reverse_iterator riter;
1542 for(riter v=stack_.rbegin(); v!=stack_.rend(); ++v) {
1543 ver_ptr& vptr = vertex_ptr(*v);
1544 if (vptr->symbol_set())
1545 continue;
1546 #if DEBUG
1547 cout << "Trying to assign symbol to operator " << (vptr)->description() << endl;
1548 #endif
1549 assign_oper_symbol(context,(vptr));
1550 }
1551
1552 }
1553
1554 void
assign_oper_symbol(const SafePtr<CodeContext> & context,SafePtr<DGVertex> & vertex)1555 DirectedGraph::assign_oper_symbol(const SafePtr<CodeContext>& context, SafePtr<DGVertex>& vertex)
1556 {
1557 // do nothing if the vertex has a symbol or is not an operator
1558 if (vertex->symbol_set())
1559 return;
1560
1561 {
1562 typedef AlgebraicOperator<DGVertex> oper;
1563 SafePtr<oper> ptr_cast = dynamic_pointer_cast<oper,DGVertex>(vertex);
1564 if (ptr_cast) {
1565 // is it in a subtree?
1566 const bool on_a_subtree = (vertex->subtree() != 0);
1567
1568 // If no -- it will be an automatic variable
1569 if (!on_a_subtree)
1570 vertex->set_symbol(context->unique_name<EntityTypes::FP>());
1571 // else assign symbols to left and right arguments
1572 else {
1573 typedef DGVertex::ArcSetType::const_iterator aciter;
1574 aciter arc = ptr_cast->first_exit_arc();
1575 SafePtr<DGVertex> left = (*arc)->dest(); ++arc;
1576 SafePtr<DGVertex> right = (*arc)->dest();
1577 assign_oper_symbol(context,left);
1578 assign_oper_symbol(context,right);
1579
1580 std::ostringstream oss;
1581 oss << "( " << left->symbol() << " ) "
1582 << ptr_cast->label()
1583 << " ( " << right->symbol() << " )";
1584 vertex->set_symbol(oss.str());
1585 }
1586 }
1587 }
1588 }
1589
1590
namespace {
  /// Counts how many times token occurs in str. Each search resumes one character past the
  /// beginning of the previous match, so overlapping occurrences are counted as well.
  unsigned int nfind(const std::string& str, const std::string& token)
  {
    unsigned int count = 0;
    for(std::string::size_type pos = str.find(token);
        pos != std::string::npos;
        pos = str.find(token, pos + 1)) {
      ++count;
    }
    return count;
  }

#define DO_NOT_COUNT_DIV 1
  /// Returns the number of FLOPs in an expression, computed as the number of " * ", " + " and " - "
  /// tokens in it; " / " tokens are counted as well unless DO_NOT_COUNT_DIV is nonzero
  unsigned int nflops(const std::string& expr)
  {
    static const std::string mul(" * ");
    static const std::string plus(" + ");
    static const std::string minus(" - ");
    unsigned int total = nfind(expr,mul);
    total += nfind(expr,plus);
    total += nfind(expr,minus);
#if !DO_NOT_COUNT_DIV
    static const std::string div(" / ");
    total += nfind(expr,div);
#endif
    return total;
  }
}
1628
1629 void
print_def(const SafePtr<CodeContext> & context,std::ostream & os,const SafePtr<ImplicitDimensions> & dims,const SafePtr<CodeSymbols> & args)1630 DirectedGraph::print_def(const SafePtr<CodeContext>& context, std::ostream& os,
1631 const SafePtr<ImplicitDimensions>& dims,
1632 const SafePtr<CodeSymbols>& args)
1633 {
1634 std::ostringstream oss;
1635 const std::string null_str("");
1636 SafePtr<Entity > ctimeconst_zero(new CTimeEntity<int>(0));
1637
1638 //
1639 // set stack ... if this function was given any arguments, the first is always stack
1640 //
1641 os << context->decldef(context->type_name<double* const>(),
1642 "stack",
1643 (args->n() >= 1) ? args->symbol(0) : registry()->stack_name());
1644
1645 const bool accumulate_targets_directly = registry()->accumulate_targets() && iregistry()->accumulate_targets_directly();
1646 const bool accumulate_targets_indirectly = registry()->accumulate_targets() && !accumulate_targets_directly;
1647
1648 //
1649 // To optimize accumulation and setting blocks to zero, need to merge maximally the targets (the accumulation area is one block)
1650 //
1651 SafePtr<MemBlockSet> targetblks;
1652 if (registry()->accumulate_targets()) {
1653 if (accumulate_targets_directly) {
1654 targetblks = to_memoryblks(targets_);
1655 merge(*targetblks);
1656 }
1657 else {
1658 targetblks = SafePtr<MemBlockSet>(new MemBlockSet);
1659 targetblks->push_back(MemBlock(0,iregistry()->size_of_target_accum(),false,SafePtr<MemBlock>(),SafePtr<MemBlock>()));
1660 }
1661 }
1662
1663 //
1664 // If accumulating integrals, check inteval's zero_out_targets. If set to 1 -- zero out accumulated targets
1665 //
1666 #if LIBINT_ACCUM_INTS
1667 if (registry()->accumulate_targets()) {
1668
1669 const bool vecdim_is_static = dims->vecdim_is_static();
1670
1671 os << "if (inteval->zero_out_targets) {" << std::endl;
1672
1673 typedef MemBlockSet::const_iterator citer;
1674 typedef MemBlockSet::iterator iter;
1675 const citer end = targetblks->end();
1676 for(iter b=targetblks->begin(); b!=end; ++b) {
1677
1678 size s = b->size();
1679 SafePtr<CTimeEntity<int> > bdim(new CTimeEntity<int>(s));
1680
1681 SafePtr<Entity> bvecdim;
1682 if (!vecdim_is_static) {
1683 SafePtr< RTimeEntity<EntityTypes::Int> > vecdim = dynamic_pointer_cast<RTimeEntity<EntityTypes::Int>,Entity>(dims->vecdim());
1684 bvecdim = vecdim * bdim;
1685 }
1686 else {
1687 SafePtr< CTimeEntity<int> > vecdim = dynamic_pointer_cast<CTimeEntity<int>,Entity>(dims->vecdim());
1688 bvecdim = vecdim * bdim;
1689 }
1690
1691 #if 0
1692 std::string loopvar("i");
1693 ForLoop loop(context,loopvar,bvecdim,ctimeconst_zero);
1694 os << loop.open();
1695 {
1696 ostringstream oss;
1697 oss << registry()->stack_name() << "[" << loopvar << "]";
1698 const std::string zero("0");
1699 os << context->assign(oss.str(),zero);
1700 }
1701 os << loop.close();
1702 #endif
1703 os << "_libint2_static_api_bzero_short_(" << registry()->stack_name() << "+"
1704 << b->address() << "*" << dims->vecdim()->id() << "," << bvecdim->id() << ")" << endl;
1705
1706 }
1707
1708 os << "inteval->zero_out_targets = 0;" << std::endl << "}" << std::endl;
1709 }
1710 #endif
1711
1712 std::string varname("hsi");
1713 SafePtr<ForLoop> hsi_loop(new ForLoop(context,varname,dims->high(),SafePtr<Entity>(new CTimeEntity<int>(0))));
1714 os << hsi_loop->open();
1715
1716 varname = "lsi";
1717 SafePtr<ForLoop> lsi_loop(new ForLoop(context,varname,dims->low(),SafePtr<Entity>(new CTimeEntity<int>(0))));
1718 os << lsi_loop->open();
1719
1720 // the vector loop is created outside of the body of the function if
1721 // 1) blockwise vectorization is requested
1722 // and
1723 // 2) this is a purely int-unit code, i.e. there are no explicit RRs on sets in the body
1724 // Otherwise, I create a dummy vector loop with the vector loop index set to 0
1725 const unsigned int max_vector_length = context->cparams()->max_vector_length();
1726 const bool vectorize = (max_vector_length != 1);
1727 const bool vectorize_by_line = context->cparams()->vectorize_by_line();
1728 const bool create_outer_vector_loop = !vectorize_by_line && !cannot_enclose_in_outer_vloop();
1729 varname = "vi";
1730 // outer vector loop
1731 SafePtr<ForLoop> outer_vloop;
1732 // vector loop for each code line
1733 SafePtr<ForLoop> line_vloop;
1734 if (create_outer_vector_loop) {
1735 SafePtr<ForLoop> tmp_vi_loop(new ForLoop(context,varname,dims->vecdim(),SafePtr<Entity>(new CTimeEntity<int>(0))));
1736 outer_vloop = tmp_vi_loop;
1737 }
1738 else {
1739 SafePtr<Entity> unit_dim(new CTimeEntity<int>(1));
1740 SafePtr<ForLoop> tmp_vi_loop(new ForLoop(context,varname,unit_dim,SafePtr<Entity>(new CTimeEntity<int>(0))));
1741 outer_vloop = tmp_vi_loop;
1742 SafePtr<ForLoop> tmp2_vi_loop(new ForLoop(context,varname,dims->vecdim(),SafePtr<Entity>(new CTimeEntity<int>(0))));
1743 line_vloop = tmp2_vi_loop;
1744 // note that both loops use same variable name -- standard C++ scoping rules allow it -- but the outer
1745 // loop will become a declaration of a constant variable
1746 }
1747 os << outer_vloop->open();
1748
1749 //
1750 // generate code for vertices
1751 //
1752 unsigned int nflops_total = 0;
1753 SafePtr<DGVertex> current_vertex = first_to_compute_;
1754 do {
1755
1756 // skip if already scheduled
1757 if (current_vertex->scheduled())
1758 goto next;
1759
1760 // skip if this is a dummy vertex (i.e. refers to someone else)
1761 if (current_vertex->refers_to_another())
1762 goto next;
1763
1764 // for every vertex that has a defined symbol, hence must be defined in code
1765 if (current_vertex->symbol_set()) {
1766
1767 // Type declaration if this is an automatic variable
1768 // ids for automatic variables cannot have characters '[' or ']'
1769 const std::string& symbol = current_vertex->symbol();
1770 if (symbol.find('[') == std::string::npos) {
1771 os << context->declare(context->type_name<double> (),
1772 current_vertex->symbol());
1773 #if CHECK_SAFETY
1774 current_vertex->declared(true);
1775 #endif
1776 }
1777
1778 // print algebraic expression
1779 {
1780 typedef AlgebraicOperator<DGVertex> oper_type;
1781 SafePtr<oper_type> oper_ptr =
1782 dynamic_pointer_cast<oper_type, DGVertex> (current_vertex);
1783 if (oper_ptr) {
1784
1785 // If this is an Integral in a target IntegralSet AND
1786 // can accumulate targets directly -- use '+=' instead of '='
1787 const bool accumulate_not_assign = accumulate_targets_directly
1788 && IntegralInTargetIntegralSet()(current_vertex);
1789
1790 typedef DGVertex::ArcSetType::const_iterator aciter;
1791 aciter a = oper_ptr->first_exit_arc();
1792 auto left_arg = (*a)->dest();
1793 ++a;
1794 auto right_arg = (*a)->dest();
1795
1796 if (context->comments_on()) {
1797
1798 oss.str(null_str);
1799 oss << current_vertex->label() << (accumulate_not_assign ? " += "
1800 : " = ")
1801 << left_arg->label() << oper_ptr->label() << right_arg->label();
1802 os << context->comment(oss.str()) << endl;
1803 }
1804
1805 // expression
1806 #if DEBUG
1807 cout << "Generating code for " << current_vertex->description() << endl;
1808 cout << " ptr = " << current_vertex << endl;
1809 cout << " left_arg = " << left_arg->description() << endl;
1810 cout << " right_arg = " << right_arg->description() << endl;
1811 #endif
1812
1813 // can we generate a composite instruction, like FMA?
1814 // - FMA: yes if:
1815 // o current_vertex is a multiply
1816 // o it has one parent and the parent is a +/- operator
1817 // o parent's other argument has been computed already
1818 bool generate_fma = false;
1819 SafePtr<oper_type> parent_oper_ptr;
1820 SafePtr<DGVertex> fma_other_arg;
1821 #if LIBINT_GENERATE_FMA
1822 {
1823 if (oper_ptr->type() == algebra::OperatorTypes::Times &&
1824 oper_ptr->num_entry_arcs() == 1) {
1825 parent_oper_ptr =
1826 dynamic_pointer_cast<oper_type, DGVertex> ( (*(oper_ptr->first_entry_arc()))->orig() );
1827 if (parent_oper_ptr != 0) {
1828 auto parent_oper_type = parent_oper_ptr->type();
1829 if (parent_oper_type == algebra::OperatorTypes::Plus ||
1830 parent_oper_type == algebra::OperatorTypes::Minus) {
1831 auto arg1 = (*(parent_oper_ptr->first_exit_arc()))->dest();
1832 auto arg2 = (*(++(parent_oper_ptr->first_exit_arc())))->dest();
1833 auto other_arg = (arg1 == current_vertex) ? arg2 : arg1;
1834 const bool other_ready = other_arg->num_exit_arcs() == 0 || other_arg->scheduled();
1835
1836 // can do fmadd if other_ready
1837 // can do fmsub if other_ready and it's arg2
1838 if (parent_oper_type == algebra::OperatorTypes::Plus)
1839 generate_fma = other_ready;
1840 else
1841 generate_fma = other_ready && (current_vertex == arg1);
1842 if (generate_fma) {
1843 //std::cout << context->comment("CAN GENERATE FMA!!!") << endl;
1844 fma_other_arg = other_arg;
1845
1846 #if DEBUG
1847 std::cout << "Generating FMA:" << std::endl;
1848 std::cout << "parent:\n";
1849 parent_oper_ptr->print(std::cout);
1850 std::cout << "current_vertex:\n";
1851 current_vertex->print(std::cout);
1852 std::cout << "other_vertex:\n";
1853 other_arg->print(std::cout);
1854 std::cout << "arg1:\n";
1855 arg1->print(std::cout);
1856 std::cout << "arg2:\n";
1857 arg2->print(std::cout);
1858 std::cout << "child1:\n";
1859 (*(parent_oper_ptr->first_exit_arc()))->dest()->print(std::cout);
1860 std::cout << "child2:\n";
1861 (*(++(parent_oper_ptr->first_exit_arc())))->dest()->print(std::cout);
1862 #endif
1863
1864
1865 // may need to declare the variable for the parent
1866 if (parent_oper_ptr->symbol_set()) {
1867
1868 // Type declaration if this is an automatic variable
1869 // ids for automatic variables cannot have characters '[' or ']'
1870 const std::string& symbol = parent_oper_ptr->symbol();
1871 if (symbol.find('[') == std::string::npos) {
1872 os << context->declare(context->type_name<double> (),
1873 parent_oper_ptr->symbol());
1874 #if CHECK_SAFETY
1875 parent_oper_ptr->declared(true);
1876 #endif
1877 }
1878 }
1879
1880 }
1881 }
1882 }
1883 }
1884 }
1885 #endif // LIBINT_GENERATE_FMA=1
1886
1887 // convert symbols to their vector form if needed
1888 std::string curr_symbol = current_vertex->symbol();
1889 std::string left_symbol = left_arg->symbol();
1890 std::string right_symbol = right_arg->symbol();
1891 std::string parent_symbol = generate_fma ? parent_oper_ptr->symbol() : "";
1892 std::string fma_other_arg_symbol = generate_fma ? fma_other_arg->symbol() : "";
1893 if (vectorize) {
1894 curr_symbol = to_vector_symbol(current_vertex);
1895 left_symbol = to_vector_symbol(left_arg);
1896 right_symbol = to_vector_symbol(right_arg);
1897 if (generate_fma) {
1898 parent_symbol = to_vector_symbol(parent_oper_ptr);
1899 fma_other_arg_symbol = to_vector_symbol(fma_other_arg);
1900 }
1901 }
1902 #if CHECK_SAFETY && 0
1903 bool left_not_declared = left_arg->need_to_compute() && !left_arg->declared();
1904 bool right_not_declared = right_arg->need_to_compute() && !right_arg->declared();
1905 if (left_not_declared || right_not_declared) {
1906 std::cout << "Current vertex:" << endl; current_vertex->print(std::cout);
1907 std::cout << "left arg :" << endl; left_arg->print(std::cout);
1908 std::cout << "right arg :" << endl; right_arg->print(std::cout);
1909 if (left_not_declared) throw ProgrammingError("DirectedGraph::print_def() -- left_arg not declared");
1910 if (right_not_declared) throw ProgrammingError("DirectedGraph::print_def() -- right_arg not declared");
1911 }
1912 #endif
1913
1914 if (vectorize_by_line)
1915 os << line_vloop->open();
1916 // the statement that does the work
1917 {
1918 if (accumulate_not_assign) {
1919
1920 if (generate_fma) {
1921 os << context->accumulate_ternary_expr(parent_symbol,
1922 left_symbol,
1923 oper_ptr->label(),
1924 right_symbol,
1925 parent_oper_ptr->label(),
1926 fma_other_arg_symbol
1927 );
1928 nflops_total += 2; // extra flop due to accumulation + extra flop due to FMA
1929 }
1930 else {
1931 os << context->accumulate_binary_expr(curr_symbol, left_symbol,
1932 oper_ptr->label(),
1933 right_symbol);
1934 nflops_total += 1; // extra flop due to accumulation
1935 }
1936
1937 } else { // assign, not accumulate
1938 if (generate_fma) {
1939 os << context->assign_ternary_expr(parent_symbol,
1940 left_symbol,
1941 oper_ptr->label(),
1942 right_symbol,
1943 parent_oper_ptr->label(),
1944 fma_other_arg_symbol
1945 );
1946 nflops_total += 1; // extra flop due to FMA
1947 }
1948 else {
1949 os
1950 << context->assign_binary_expr(curr_symbol, left_symbol,
1951 oper_ptr->label(),
1952 right_symbol);
1953 }
1954 }
1955
1956 nflops_total += (1 + nflops(left_symbol) + nflops(right_symbol));
1957 }
1958 if (vectorize_by_line)
1959 os << line_vloop->close();
1960
1961 // if produced FMA, do not forget to mark the parent scheduled
1962 if (generate_fma) {
1963 parent_oper_ptr->schedule();
1964 }
1965
1966 goto next;
1967 }
1968 }
1969
1970 // print simple assignment statement
1971 if (current_vertex->num_exit_arcs() == 1) {
1972 typedef DGArcDirect arc_type;
1973 SafePtr<arc_type>
1974 arc_ptr =
1975 dynamic_pointer_cast<arc_type, DGArc> (
1976 *(current_vertex->first_exit_arc()));
1977 if (arc_ptr) {
1978
1979 #if CHECK_SAFETY
1980 current_vertex->declared(true);
1981
1982 SafePtr<DGVertex> rhs_arg = arc_ptr->dest();
1983 if (!rhs_arg->declared() && rhs_arg->need_to_compute()) {
1984 std::cout << "Current vertex:" << endl; current_vertex->print(std::cout);
1985 std::cout << "rhs_arg :" << endl; rhs_arg->print(std::cout);
1986 throw ProgrammingError("DirectedGraph::print_def() -- rhs_arg not declared");
1987 }
1988 #endif
1989
1990 // If this is an Integral in a target IntegralSet AND
1991 // can accumulate targets directly -- use '+=' instead of '='
1992 const bool accumulate_not_assign = accumulate_targets_directly
1993 && IntegralInTargetIntegralSet()(current_vertex);
1994
1995 if (context->comments_on()) {
1996 oss.str(null_str);
1997 oss << current_vertex->label() << (accumulate_not_assign ? " += "
1998 : " = ")
1999 << arc_ptr->dest()->label();
2000 os << context->comment(oss.str()) << endl;
2001 }
2002
2003 // convert symbols to their vector form if needed
2004 std::string curr_symbol = current_vertex->symbol();
2005 std::string rhs_symbol = arc_ptr->dest()->symbol();
2006 if (vectorize) {
2007 curr_symbol = to_vector_symbol(current_vertex);
2008 rhs_symbol = to_vector_symbol(arc_ptr->dest());
2009 }
2010
2011 if (vectorize_by_line)
2012 os << line_vloop->open();
2013 if (accumulate_not_assign) {
2014 os << context->accumulate(curr_symbol, rhs_symbol);
2015 nflops_total += nflops(rhs_symbol) + 1; // +1 due to +=
2016 } else {
2017 os << context->assign(curr_symbol, rhs_symbol);
2018 nflops_total += nflops(rhs_symbol);
2019 }
2020 if (vectorize_by_line)
2021 os << line_vloop->close();
2022
2023 goto next;
2024 }
2025 }
2026
2027 // print out a recurrence relation
2028 if (current_vertex->num_exit_arcs() != 0) {
2029 // printing a recurrence relation
2030 //std::cout << "DirectedGraph::print_def(): a RR making " << current_vertex->description() << std::endl;
2031 typedef DGArcRR arc_type;
2032 SafePtr<arc_type>
2033 arc_ptr =
2034 dynamic_pointer_cast<arc_type, DGArc> (
2035 *(current_vertex->first_exit_arc()));
2036 if (arc_ptr) {
2037 SafePtr<RecurrenceRelation> rr = arc_ptr->rr();
2038 os << rr->spfunction_call(context, dims);
2039 nflops_total += rr->nflops();
2040
2041 goto next;
2042 }
2043 }
2044
2045 #if 0
2046 {
2047 current_vertex->print(std::cout);
2048 throw std::runtime_error(
2049 "DirectedGraph::print_def() -- cannot handle this vertex yet");
2050 }
2051 #endif
2052 }
2053
2054 next:
2055 current_vertex->schedule();
2056 current_vertex = current_vertex->postcalc();
2057
2058 } while (current_vertex != 0);
2059
2060 os << outer_vloop->close();
2061 os << lsi_loop->close();
2062 os << hsi_loop->close();
2063
2064 //
2065 // Accumulate targets
2066 //
2067 if (accumulate_targets_indirectly) {
2068 os << context->comment("Accumulate target integral sets") << std::endl;
2069 //const std::string& stack_name = registry()->stack_name();
2070 const bool vecdim_is_static = dims->vecdim_is_static();
2071 #if 0
2072 unsigned int vecdim_rank;
2073 if (vecdim_is_static) {
2074 SafePtr<CTimeEntity<int> > cptr = dynamic_pointer_cast<CTimeEntity<int> ,Entity >(dims->vecdim());
2075 vecdim_rank = cptr->value();
2076 }
2077 const std::string times_vecdim("*" + dims->vecdim_label());
2078 #endif
2079
2080 // Loop over targets
2081 unsigned int curr_target = 0;
2082 for(target_iter t=targets_.begin(); t!=targets_.end(); ++t, ++curr_target) {
2083 const ver_ptr& tptr = vertex_ptr(*t);
2084
2085 size s = (tptr)->size();
2086 SafePtr<CTimeEntity<int> > bdim(new CTimeEntity<int>(s));
2087
2088 SafePtr<Entity> bvecdim;
2089 if (!vecdim_is_static) {
2090 SafePtr< RTimeEntity<EntityTypes::Int> > vecdim = dynamic_pointer_cast<RTimeEntity<EntityTypes::Int>,Entity>(dims->vecdim());
2091 bvecdim = vecdim * bdim;
2092 }
2093 else {
2094 SafePtr< CTimeEntity<int> > vecdim = dynamic_pointer_cast<CTimeEntity<int>,Entity>(dims->vecdim());
2095 bvecdim = vecdim * bdim;
2096 }
2097
2098 // For now write an explicit loop for each target. In the future should:
2099 // 1) check if all computed targets are adjacent (accumulated targets are adjacent by the virtue of the allocation mechanism)
2100 // 2) check the sizes and insert optimized calls, if possible
2101
2102 // form a single loop over integrals and vector dimension
2103 // NOTE: single loop suffices because if outer/inner strides are not 1 this block of code should not be executed
2104
2105 #if 0
2106 SafePtr<Entity> loopmax;
2107 if (vecdim_is_static) {
2108 const int loopmax_value = s*vecdim_rank;
2109 loopmax = SafePtr<Entity>(new CTimeEntity<int>(loopmax_value));
2110 }
2111 else {
2112 ostringstream oss; oss << s << times_vecdim;
2113 loopmax = SafePtr<Entity>(new RTimeEntity<EntityTypes::Int>(oss.str()));
2114 }
2115
2116 std::string loopvar("i");
2117 ForLoop loop(context,loopvar,loopmax,ctimeconst_zero);
2118 os << loop.open();
2119 std::string acctarget;
2120 {
2121 ostringstream oss;
2122 oss << stack_name << "[" << target_accums_[curr_target] << times_vecdim << "+" << loopvar << "]";
2123 acctarget = oss.str();
2124 }
2125 std::string target;
2126 {
2127 ostringstream oss;
2128 oss << stack_name << "[" << (tptr)->address() << times_vecdim << "+" << loopvar << "]";
2129 target = oss.str();
2130 }
2131 os << context->accumulate(acctarget,target);
2132 os << loop.close();
2133 #endif
2134 os << "_libint2_static_api_inc1_short_("
2135 << registry()->stack_name() << "+" << target_accums_[curr_target] << "*" << dims->vecdim()->id() << ","
2136 << registry()->stack_name() << "+" << (tptr)->address() << "*" << dims->vecdim()->id() << ","
2137 << bvecdim->id() << ")" << endl;
2138
2139 nflops_total += s;
2140 }
2141 }
2142
2143 // Outside of loops stack symbols don't make sense, so we must define loop variables hsi, lsi, and vi to 0
2144 os << context->decldef(context->type_name<const int>(), "hsi", "0");
2145 os << context->decldef(context->type_name<const int>(), "lsi", "0");
2146 os << context->decldef(context->type_name<const int>(), "vi", "0");
2147
2148 //
2149 // Now pass back all targets through the inteval object, if needed.
2150 //
2151 if (registry()->return_targets()) {
2152 unsigned int curr_target = 0;
2153 for(target_iter t=targets_.begin(); t!=targets_.end(); ++t, ++curr_target) {
2154 const ver_ptr& tptr = vertex_ptr(*t);
2155 const std::string& symbol = (accumulate_targets_indirectly
2156 // is this correct? ???
2157 ? stack_symbol(context,target_accums_[curr_target],(tptr)->size(),dims->low_label(),dims->vecdim_label(),registry()->stack_name())
2158 : (tptr)->symbol());
2159 os << "inteval->targets[" << curr_target << "] = "
2160 << context->value_to_pointer(symbol) << context->end_of_stat() << endl;
2161 }
2162 }
2163
2164 // Print out the number of flops
2165 oss.str(null_str);
2166 oss << "Number of flops = " << nflops_total;
2167 os << context->comment(oss.str()) << endl;
2168
2169 if (context->cparams()->count_flops()) {
2170 oss.str(null_str);
2171 oss << nflops_total << " * " << dims->high_label() << " * "
2172 << dims->low_label() << " * "
2173 << dims->vecdim_label();
2174 os << context->assign_binary_expr("inteval->nflops[0]","inteval->nflops[0]","+",oss.str());
2175 }
2176
2177 }
2178
2179 void
update_func_names()2180 DirectedGraph::update_func_names()
2181 {
2182 // Loop over all vertices
2183 typedef vertices::const_iterator citer;
2184 typedef vertices::iterator iter;
2185 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
2186 const ver_ptr& vptr = vertex_ptr(*v);
2187 // for every vertex with children
2188 if ((vptr)->num_exit_arcs() > 0) {
2189 // if it must be computed using a RR
2190 SafePtr<DGArc> arc = *((vptr)->first_exit_arc());
2191 SafePtr<DGArcRR> arcrr = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
2192 if (arcrr != 0) {
2193 SafePtr<RecurrenceRelation> rr = arcrr->rr();
2194 // and the RR is complex (i.e. likely to result in a function call)
2195 if (!rr->is_simple()) {
2196 // add it to the RRStack
2197 func_names_[rr->label()] = true;
2198 }
2199 }
2200 }
2201 }
2202 }
2203
2204 bool
cannot_enclose_in_outer_vloop() const2205 DirectedGraph::cannot_enclose_in_outer_vloop() const
2206 {
2207 SafePtr<DGVertex> current_vertex = first_to_compute_;
2208 do {
2209 const int nchildren = current_vertex->num_exit_arcs();
2210 if (nchildren > 0) {
2211 arc_ptr aptr = *(current_vertex->first_exit_arc());
2212 SafePtr<DGArcRR> aptr_cast = dynamic_pointer_cast<DGArcRR,arc>(aptr);
2213 // if this is a RR
2214 if (aptr_cast != 0) {
2215 // and a non-trivial one
2216 SafePtr<RecurrenceRelation> rr = aptr_cast->rr();
2217 if (!rr->is_simple()) {
2218 // if this is an Uncontract_Integral call, return false
2219 SafePtr<Uncontract_Integral_base> rr_ucb_ptr = dynamic_pointer_cast<Uncontract_Integral_base,RecurrenceRelation>(rr);
2220 if (rr_ucb_ptr)
2221 return false;
2222 else
2223 return true;
2224 }
2225 }
2226 }
2227 current_vertex = current_vertex->postcalc();
2228 } while (current_vertex != 0);
2229
2230 return false;
2231 }
2232
2233 void
find_subtrees()2234 DirectedGraph::find_subtrees()
2235 {
2236 // need to condense expressions?
2237 if (!registry()->condense_expr())
2238 return;
2239
2240 // Find subtrees by starting from the targets and moving down ...
2241 typedef vertices::const_iterator citer;
2242 typedef vertices::iterator iter;
2243 for(iter v=stack_.begin(); v!=stack_.end(); ++v) {
2244 const ver_ptr& vptr = vertex_ptr(*v);
2245 if ((vptr)->is_a_target() && (vptr)->num_entry_arcs() == 0) {
2246 find_subtrees_from(vptr);
2247 }
2248 }
2249 }
2250
2251 void
find_subtrees_from(const SafePtr<DGVertex> & v)2252 DirectedGraph::find_subtrees_from(const SafePtr<DGVertex>& v)
2253 {
2254 // is not on a subtree already
2255 if (!v->subtree()) {
2256
2257 bool useless_subtree = false;
2258
2259 //
2260 // Subtrees are useless in the following cases:
2261 // 1) root is computed via RRs, not explicitly
2262 //
2263 {
2264 if (v->num_exit_arcs() > 0) {
2265 SafePtr<DGArc> arc = *(v->first_exit_arc());
2266 SafePtr<DGArcRR> arc_rr = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
2267 if (arc_rr)
2268 useless_subtree = true;
2269 }
2270 }
2271
2272 // create subtree
2273 if (!useless_subtree) {
2274 SafePtr<DRTree> stree = DRTree::CreateRootedAt(v);
2275
2276 // Remove all trivial subtrees
2277 if (stree) {
2278 #if DISABLE_SUBTREES
2279 if (stree->nvertices() >= 0) {
2280 stree->detach();
2281 }
2282 #else
2283 if (stree->nvertices() < 3) {
2284 stree->detach();
2285 }
2286 #endif
2287 }
2288 }
2289
2290 // move on to children
2291 typedef DGVertex::ArcSetType::const_iterator aciter;
2292 const aciter abegin = v->first_exit_arc();
2293 const aciter aend = v->plast_exit_arc();
2294 for(aciter a=abegin; a!=aend; ++a) {
2295 find_subtrees_from((*a)->dest());
2296 }
2297 }
2298 }
2299
2300 namespace {
2301 struct PrerequisiteNotComputed {
operator ()__anonad020a590a11::PrerequisiteNotComputed2302 bool operator()(const DirectedGraph::vertices::value_type& v) {
2303 return !v.second->precomputed() && v.second->num_exit_arcs() == 0;
2304 }
2305 };
2306 }
2307
2308 /// return true if there are vertices with 0 children but not pre-computed
2309 bool
missing_prerequisites() const2310 DirectedGraph::missing_prerequisites() const {
2311 bool missing_prereqs = false;
2312 if (!this->registry()->ignore_missing_prereqs()) {
2313 #if 0
2314 missing_prereqs =
2315 find_if(this->stack_.begin(), this->stack_.end(), [](const vertices::value_type& v) {
2316 return v.second->precomputed() == false && v.second->num_exit_arcs() == 0;
2317 }) != this->stack_.end();
2318 #else
2319 PrerequisiteNotComputed pred;
2320 missing_prereqs =
2321 std::find_if(stack_.begin(), stack_.end(), pred) != stack_.end();
2322 #endif
2323 }
2324 return missing_prereqs;
2325 }
2326
2327 ////
2328
2329 namespace libint2 {
2330
2331 namespace {
2332 #if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
2333 typedef DirectedGraph::targets::value_type value_type;
2334 struct __NotUnrolledIntegralSet : public std::unary_function<const value_type&,bool> {
operator ()libint2::__anonad020a590b11::__NotUnrolledIntegralSet2335 bool operator()(const value_type& v) {
2336 return NotUnrolledIntegralSet()(v);
2337 }
2338 };
2339 #endif
2340 };
2341
2342 bool
nonunrolled_targets(const DirectedGraph::targets & targets)2343 nonunrolled_targets(const DirectedGraph::targets& targets) {
2344 typedef DirectedGraph::target_citer citer;
2345 citer end = targets.end();
2346 #if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH
2347 if (end != find_if(targets.begin(),end,__NotUnrolledIntegralSet()))
2348 #else
2349 if (end != find_if(targets.begin(),end,NotUnrolledIntegralSet()))
2350 #endif
2351 return true;
2352 else
2353 return false;
2354 }
2355
2356 void
extract_symbols(const SafePtr<DirectedGraph> & dg)2357 extract_symbols(const SafePtr<DirectedGraph>& dg)
2358 {
2359 LibraryTaskManager& taskmgr = LibraryTaskManager::Instance();
2360 // symbol extractor
2361 {
2362 SafePtr<ExtractExternSymbols> extractor(new ExtractExternSymbols);
2363 dg->foreach(*extractor);
2364 const ExtractExternSymbols::Symbols& symbols = extractor->symbols();
2365 // pass on to the symbol maintainer of the current task
2366 taskmgr.current().symbols()->add(symbols);
2367 #if DEBUG
2368 // print out the symbols
2369 std::cout << "Recovered symbols from DirectedGraph for " << dg << std::endl;
2370 typedef ExtractExternSymbols::Symbols::const_iterator citer;
2371 citer end = symbols.end();
2372 for(citer t=symbols.begin(); t!=end; ++t)
2373 std::cout << *t << std::endl;
2374 #endif
2375 }
2376 // RR extractor
2377 {
2378 SafePtr<ExtractRR> extractor(new ExtractRR);
2379 dg->foreach(*extractor);
2380 const ExtractRR::RRList& rrlist = extractor->rrlist();
2381 // pass on to the symbol maintainer of the current task
2382 taskmgr.current().symbols()->add(rrlist);
2383 }
2384 }
2385
2386 //////////////////////
2387 void
operator ()(const SafePtr<DGVertex> & v)2388 PrerequisitesExtractor::operator()(const SafePtr<DGVertex>& v) {
2389 #if DEBUG
2390 std::cout << "PrerequisitesExtractor: considering " << v->description() << std::endl;
2391 v->print(std::cout);
2392 #endif
2393 if (!v->precomputed() &&
2394 v->num_exit_arcs() == 0) {
2395 #if DEBUG
2396 std::cout << "PrerequisitesExtractor: found candidate " << v->description() << std::endl;
2397 #endif
2398
2399 #define EXTRACTINTEGRALSETS 1
2400 #if EXTRACTINTEGRALSETS
2401 // if this is an integral that was a member of a shell set, add the parent set to the prerequisite level
2402 bool member_of_shellset = false;
2403 SafePtr<DGVertex> parent_shellset;
2404 typedef DGVertex::ArcSetType::const_iterator citer;
2405 for(citer i=v->entry_arcs().begin(); i != v->entry_arcs().end() && !member_of_shellset; ++i) {
2406 const SafePtr<DGArc>& arc = *i;
2407 SafePtr<DGArcRR> arc_cast = dynamic_pointer_cast<DGArcRR,DGArc>(arc);
2408 if (arc_cast) {
2409 SafePtr<RecurrenceRelation> rr = arc_cast->rr();
2410 SafePtr<IntegralSet_to_Integrals_base> rr_cast = dynamic_pointer_cast<IntegralSet_to_Integrals_base,RecurrenceRelation>(rr);
2411 if (rr_cast) {
2412 member_of_shellset = true;
2413 parent_shellset = rr->rr_target();
2414 }
2415 }
2416 }
2417 if (member_of_shellset) {
2418 #if DEBUG
2419 std::cout << "PrerequisitesExtractor: " << v->description() << " is a member of a shell set, will add that instead"<< std::endl;
2420 #endif
2421 if ( vertices.end() == find(vertices.begin(), vertices.end(), parent_shellset) ) {
2422 vertices.push_back(parent_shellset);
2423 #if DEBUG
2424 std::cout << "PrerequisitesExtractor: extracted " << parent_shellset->description() << std::endl;
2425 #endif
2426 }
2427 else {
2428 #if DEBUG
2429 std::cout << "PrerequisitesExtractor: candidate's parent already extracted" << std::endl;
2430 #endif
2431 }
2432 }
2433 else {
2434 vertices.push_back(v);
2435 #if DEBUG
2436 std::cout << "PrerequisitesExtractor: extracted " << v->description() << std::endl;
2437 #endif
2438 }
2439
2440 #else
2441
2442 vertices.push_back(v);
2443 #if DEBUG
2444 std::cout << "PrerequisitesExtractor: extracted " << v->description() << std::endl;
2445 #endif
2446
2447 #endif
2448 }
2449 }
2450 //////////////////////
2451 void
operator ()(const SafePtr<DGVertex> & v)2452 VertexPrinter::operator()(const SafePtr<DGVertex>& v) {
2453 os << "VertexPrinter: " << v->description() << std::endl;
2454 #if DEBUG
2455 v->print(os);
2456 #endif
2457 }
2458
2459 };
2460