1 /* 2 * Copyright (C) 2004-2021 Edward F. Valeev 3 * 4 * This file is part of Libint. 5 * 6 * Libint is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * Libint is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU General Public License for more details. 15 * 16 * You should have received a copy of the GNU General Public License 17 * along with Libint. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 #ifndef _libint2_src_bin_libint_dg_h_ 22 #define _libint2_src_bin_libint_dg_h_ 23 24 #include <iostream> 25 #include <string> 26 #include <list> 27 #include <map> 28 #include <vector> 29 #include <deque> 30 #include <algorithm> 31 #include <stdexcept> 32 #include <cassert> 33 34 #include <global_macros.h> 35 #include <exception.h> 36 #include <smart_ptr.h> 37 #include <key.h> 38 #include <dgvertex.h> 39 40 namespace libint2 { 41 42 // class DGVertex; 43 class DGArc; 44 template <class T> class DGArcRel; 45 template <class T> class AlgebraicOperator; 46 class Strategy; 47 class Tactic; 48 class CodeContext; 49 class MemoryManager; 50 class ImplicitDimensions; 51 class CodeSymbols; 52 class GraphRegistry; 53 class InternalGraphRegistry; 54 55 /** DirectedGraph is an implementation of a directed graph 56 composed of vertices represented by DGVertex objects. Most important operations 57 will assume that this is a DAG, i.e. there are no directed cycles. 58 59 \note The objects are allocated on free store and the graph is implemented as 60 an object of type 'vertices'. 61 */ 62 63 class DirectedGraph : public EnableSafePtrFromThis<DirectedGraph> { 64 public: 65 typedef DGVertex vertex; 66 typedef DGArc arc; 67 typedef SafePtr<DGVertex> ver_ptr; 68 typedef SafePtr<DGArc> arc_ptr; 69 typedef DGVertexKey key_type; 70 typedef std::multimap<key_type,ver_ptr> VPtrAssociativeContainer; 71 typedef std::list<ver_ptr> VPtrSequenceContainer; 72 73 typedef VPtrSequenceContainer targets; 74 #if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH 75 typedef VPtrAssociativeContainer vertices; 76 #else 77 typedef VPtrSequenceContainer vertices; 78 #endif 79 typedef targets::iterator target_iter; 80 typedef targets::const_iterator target_citer; 81 typedef vertices::iterator ver_iter; 82 typedef vertices::const_iterator ver_citer; 83 //not possible: typedef vertex::Address address; 84 typedef int address; 85 //not possible: typedef vertex::Size size; 86 typedef unsigned int size; 87 typedef std::vector<address> addresses; 88 89 private: 90 /// converts what is stored in the container to a smart ptr to the vertex vertex_ptr(const VPtrAssociativeContainer::value_type & v)91 static inline const ver_ptr& vertex_ptr(const VPtrAssociativeContainer::value_type& v) { 92 return v.second; 93 } vertex_ptr(const VPtrSequenceContainer::value_type & v)94 static inline const ver_ptr& vertex_ptr(const VPtrSequenceContainer::value_type& v) { 95 return v; 96 } 97 /// converts what is stored in the container to a smart ptr to the vertex vertex_ptr(VPtrAssociativeContainer::value_type & v)98 static inline ver_ptr& vertex_ptr(VPtrAssociativeContainer::value_type& v) { 99 return v.second; 100 } vertex_ptr(VPtrSequenceContainer::value_type & v)101 static inline ver_ptr& vertex_ptr(VPtrSequenceContainer::value_type& v) { 102 return v; 103 } 104 105 public: 106 107 /** Creates an empty DAG. Actual initialization of the graph 108 must be done using append_target */ 109 DirectedGraph(); 110 ~DirectedGraph(); 111 112 /// Returns the number of vertices num_vertices()113 unsigned int num_vertices() const { return stack_.size(); } 114 #if 0 115 /// Returns all vertices 116 const vertices& all_vertices() const { return stack_; } 117 /// Returns all targets 118 const targets& all_targets() const { return targets_; } 119 #endif 120 /// Find vertex v or it's equivalent on the graph. Return null pointer if not found. 121 /// Computational cost for a graph based on a nonassociative container may be high find(const SafePtr<DGVertex> & v)122 const SafePtr<DGVertex>& find(const SafePtr<DGVertex>& v) const { return vertex_is_on(v); } 123 124 /** appends v to the graph. If v's copy is already on the graph, return the pointer 125 to the copy. Else return pointer to *v. 126 */ 127 SafePtr<DGVertex> append_vertex(const SafePtr<DGVertex>& v); 128 129 /** non-template append_target appends the vertex to the graph as a target 130 */ 131 void append_target(const SafePtr<DGVertex>&); 132 133 /** append_target appends I to the graph as a target vertex and applies 134 RR to it. append_target can be called multiple times on the same 135 graph if more than one target vertex is needed. 136 137 I must derive from DGVertex. RR must derive from RecurrenceRelation. 138 RR has a constructor which takes const I& as the only argument. 139 RR must have a public member const I* child(unsigned int) . 140 141 NOTE TO SELF : need to implement these restrictions using 142 standard Bjarne Stroustrup's approach. 143 144 */ append_target(const SafePtr<I> & target)145 template <class I, class RR> void append_target(const SafePtr<I>& target) { 146 target->make_a_target(); 147 recurse<I,RR>(target); 148 } 149 150 /** apply_to_all applies RR to all vertices already on the graph. 151 152 RR must derive from RecurrenceRelation. RR must define TargetType 153 as a typedef. 154 RR must have a public member const DGVertex* child(unsigned int) . 155 156 NOTE TO SELF : need to implement these restrictions using 157 standard Bjarne Stroustrup's approach. 158 159 */ apply_to_all()160 template <class RR> void apply_to_all() { 161 typedef typename RR::TargetType TT; 162 typedef vertices::const_iterator citer; 163 typedef vertices::iterator iter; 164 for(iter v=stack_.begin(); v!=stack_.end(); ++v) { 165 ver_ptr& vptr = vertex_ptr(*v); 166 if ((vptr)->num_exit_arcs() != 0) 167 continue; 168 SafePtr<TT> tptr = dynamic_pointer_cast<TT,DGVertex>(v); 169 if (tptr == 0) 170 continue; 171 172 SafePtr<RR> rr0(new RR(tptr)); 173 const int num_children = rr0->num_children(); 174 175 for(int c=0; c<num_children; c++) { 176 177 SafePtr<DGVertex> child = rr0->child(c); 178 SafePtr<DGArc> arc(new DGArcRel<RR>(tptr,child,rr0)); 179 tptr->add_exit_arc(arc); 180 181 recurse<RR>(child); 182 183 } 184 } 185 } 186 187 /** after all append_target's have been called, apply(strategy,tactic) 188 constructs a graph. strategy specifies how to apply recurrence relations. 189 The goal of strategies is to connect the target vertices to simpler, precomputable vertices. 190 There usually are many ways to reduce a vertex. 191 Tactic specifies which of these possibilities to choose. 192 */ 193 void apply(const SafePtr<Strategy>& strategy, 194 const SafePtr<Tactic>& tactic); 195 196 typedef void (DGVertex::* DGVertexMethodPtr)(); 197 /** apply_at<method>(vertex) calls method() on vertex and all of its descendants 198 */ 199 template <DGVertexMethodPtr method> apply_at(const SafePtr<DGVertex> & vertex)200 void apply_at(const SafePtr<DGVertex>& vertex) const { 201 ((vertex.get())->*method)(); 202 typedef DGVertex::ArcSetType::const_iterator aciter; 203 const aciter abegin = vertex->first_exit_arc(); 204 const aciter aend = vertex->plast_exit_arc(); 205 for(aciter a=abegin; a!=aend; ++a) 206 apply_at<method>((*a)->dest()); 207 } 208 209 /** calls Method(v) for each v, iterating in forward direction */ 210 template <class Method> foreach(Method & m)211 void foreach(Method& m) { 212 typedef vertices::const_iterator citer; 213 typedef vertices::iterator iter; 214 for(iter v=stack_.begin(); v!=stack_.end(); ++v) { 215 ver_ptr& vptr = vertex_ptr(*v); 216 m(vptr); 217 } 218 } 219 220 /** calls Method(v) for each v, iterating in forward direction */ 221 template <class Method> foreach(Method & m)222 void foreach(Method& m) const { 223 typedef vertices::const_iterator citer; 224 typedef vertices::iterator iter; 225 for(citer v=stack_.begin(); v!=stack_.end(); ++v) { 226 const ver_ptr& vptr = vertex_ptr(*v); 227 m(vptr); 228 } 229 } 230 /** calls Method(v) for each v, iterating in reverse direction */ 231 template <class Method> rforeach(Method & m)232 void rforeach(Method& m) { 233 typedef vertices::const_reverse_iterator criter; 234 typedef vertices::reverse_iterator riter; 235 for(riter v=stack_.rbegin(); v!=stack_.rend(); ++v) { 236 ver_ptr& vptr = vertex_ptr(*v); 237 m(vptr); 238 } 239 } 240 241 /** calls Method(v) for each v, iterating in reverse direction */ 242 template <class Method> rforeach(Method & m)243 void rforeach(Method& m) const { 244 typedef vertices::const_reverse_iterator criter; 245 typedef vertices::reverse_iterator riter; 246 for(criter v=stack_.rbegin(); v!=stack_.rend(); ++v) { 247 const ver_ptr& vptr = vertex_ptr(*v); 248 m(vptr); 249 } 250 } 251 252 /** after Strategy has been applied, simple recurrence relations need to be 253 optimized away. optimize_rr_out() will replace all simple recurrence relations 254 with code representing them. 255 */ 256 void optimize_rr_out(const SafePtr<CodeContext>& context); 257 258 /** after all apply's have been called, traverse() 259 construct a heuristic order of traversal for the graph. 260 */ 261 void traverse(); 262 263 /** update func_names_ 264 */ 265 void update_func_names(); 266 267 /// Prints out call sequence 268 void debug_print_traversal(std::ostream& os) const; 269 270 /** 271 Prints out the graph in format understood by program "dot" 272 of package "graphviz". If symbols is true then label vertices 273 using their symbols rather than (descriptive) labels. 274 */ 275 void print_to_dot(bool symbols, std::ostream& os = std::cout) const; 276 277 /** 278 Generates code for the current computation using context. 279 dims specifies the implicit dimensions, 280 args specifies the code symbols for the arguments to the function, 281 label specifies the tag for the computation, 282 decl specifies the stream to receive declaration code, 283 code specifies the stream to receive the definition code 284 */ 285 void generate_code(const SafePtr<CodeContext>& context, const SafePtr<MemoryManager>& memman, 286 const SafePtr<ImplicitDimensions>& dims, const SafePtr<CodeSymbols>& args, 287 const std::string& label, 288 std::ostream& decl, std::ostream& code); 289 290 /** Resets the graph and all vertices. The stack of unresolved recurrence 291 relations is preserved. 292 */ 293 void reset(); 294 295 /** num_children_on(rr) returns the number of children of rr which 296 are already on this graph. 297 */ 298 template <class RR> 299 unsigned int num_children_on(const SafePtr<RR> & rr)300 num_children_on(const SafePtr<RR>& rr) const { 301 unsigned int nchildren = rr->num_children(); 302 unsigned int nchildren_on_stack = 0; 303 for(unsigned int c=0; c<nchildren; c++) { 304 if (!vertex_is_on(rr->rr_child(c))) 305 continue; 306 else 307 nchildren_on_stack++; 308 } 309 310 return nchildren_on_stack; 311 } 312 313 /// Returns the registry registry()314 SafePtr<GraphRegistry>& registry() { return registry_; } registry()315 const SafePtr<GraphRegistry>& registry() const { return registry_; } 316 317 /// return the graph label label()318 const std::string& label() const { return label_; } 319 /// sets label to \c new_label set_label(const std::string & new_label)320 void set_label(const std::string& new_label) { label_ = new_label; } 321 322 /// return true if there are vertices with 0 children but not pre-computed 323 bool missing_prerequisites() const; 324 325 private: 326 327 /// contains vertices 328 vertices stack_; 329 /// refers to targets, cannot be an associative container -- order of iteration over targets is important 330 targets targets_; 331 /// addresses of blocks which accumulate targets 332 addresses target_accums_; 333 334 // graph label, used for annotating internal work, e.g. graphviz plots 335 std::string label_; 336 337 typedef std::map<std::string,bool> FuncNameContainer; 338 /** Maintains the list of names of functions calls to which have been generated so far. 339 It is used to generate include statements. 340 */ 341 FuncNameContainer func_names_; 342 343 #if !USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH 344 static const unsigned int default_size_ = 100; 345 #endif 346 347 // maintains data about the graph which does not belong IN the graph 348 SafePtr<GraphRegistry> registry_; 349 // maintains private data about the graph which does not belong IN the graph 350 SafePtr<InternalGraphRegistry> iregistry_; 351 352 /// Access the internal registry iregistry()353 SafePtr<InternalGraphRegistry>& iregistry() { return iregistry_; } iregistry()354 const SafePtr<InternalGraphRegistry>& iregistry() const { return iregistry_; } 355 356 /** adds a vertex to the graph. If the vertex already found on the graph 357 then the vertex is not added and the function returns false */ 358 SafePtr<DGVertex> add_vertex(const SafePtr<DGVertex>&); 359 /** same as add_vertex(), only assumes that there's no equivalent vertex on the graph (see vertex_is_on) */ 360 void add_new_vertex(const SafePtr<DGVertex>&); 361 /// returns true if vertex if already on graph 362 const SafePtr<DGVertex>& vertex_is_on(const SafePtr<DGVertex>& vertex) const; 363 /// removes vertex from the graph. may throw CannotPerformOperation 364 void del_vertex(vertices::iterator&); 365 /** This function is used to implement (recursive) append_target(). 366 vertex is appended to the graph and then RR is applied to is. 367 */ recurse(const SafePtr<I> & vertex)368 template <class I, class RR> void recurse(const SafePtr<I>& vertex) { 369 SafePtr<DGVertex> dgvertex = add_vertex(vertex); 370 if (dgvertex != vertex) 371 return; 372 373 SafePtr<RR> rr0(new RR(vertex)); 374 const int num_children = rr0->num_children(); 375 376 for(int c=0; c<num_children; c++) { 377 378 SafePtr<DGVertex> child = rr0->child(c); 379 SafePtr<DGArc> arc(new DGArcRel<RR>(vertex,child,rr0)); 380 vertex->add_exit_arc(arc); 381 382 SafePtr<I> child_cast = dynamic_pointer_cast<I,DGVertex>(child); 383 if (child_cast == 0) 384 throw std::runtime_error("DirectedGraph::recurse(const SafePtr<I>& vertex) -- dynamic cast failed, most probably this is a logic error!"); 385 recurse<I,RR>(child_cast); 386 387 } 388 } 389 390 /** This function is used to implement (recursive) apply_to_all(). 391 RR is applied to vertex and all its children. 392 */ recurse(const SafePtr<DGVertex> & vertex)393 template <class RR> void recurse(const SafePtr<DGVertex>& vertex) { 394 SafePtr<DGVertex> dgvertex = add_vertex(vertex); 395 if (dgvertex != vertex) 396 return; 397 398 typedef typename RR::TargetType TT; 399 SafePtr<TT> tptr = dynamic_pointer_cast<TT,DGVertex>(vertex); 400 if (tptr == 0) 401 return; 402 403 SafePtr<RR> rr0(new RR(tptr)); 404 const int num_children = rr0->num_children(); 405 406 for(int c=0; c<num_children; c++) { 407 408 SafePtr<DGVertex> child = rr0->child(c); 409 SafePtr<DGArc> arc(new DGArcRel<RR>(vertex,child,rr0)); 410 vertex->add_exit_arc(arc); 411 412 recurse<RR>(child); 413 } 414 } 415 416 /** This function is used to implement (recursive) apply(). 417 strategy and tactic are applied to vertex and all its children. 418 */ 419 void apply_to(const SafePtr<DGVertex>& vertex, 420 const SafePtr<Strategy>& strategy, 421 const SafePtr<Tactic>& tactic); 422 /// This function insert expr of type AlgebraicOperator<DGVertex> into the graph 423 SafePtr<DGVertex> insert_expr_at(const SafePtr<DGVertex>& where, const SafePtr< AlgebraicOperator<DGVertex> >& expr); 424 /// This function replaces RecurrenceRelations with concrete arithemtical expressions 425 void replace_rr_with_expr(); 426 /// This function gets rid of trivial math such as multiplication/division by 1.0, etc. 427 void remove_trivial_arithmetics(); 428 /** This function gets rid of nodes which are connected 429 to their equivalents (such as (ss|ss) shell quartet can only be connected to (ss|ss) integral) 430 */ 431 void handle_trivial_nodes(const SafePtr<CodeContext>& context); 432 /// This functions removes vertices not connected to other vertices 433 void remove_disconnected_vertices(); 434 /** Finds (binary) subtrees. The subtrees correspond to a single-line code (no intermediates 435 are used in other expressions) 436 */ 437 void find_subtrees(); 438 /** Finds (binary) subtrees starting (recursively) at v. 439 */ 440 void find_subtrees_from(const SafePtr<DGVertex>& v); 441 /** If v1 and v2 are connected by DGArcDirect and all entry arcs to v1 are of the DGArcDirect type as well, 442 this function will reattach all arcs extering v1 to v2 and remove v1 from the graph altogether. 443 May throw CannotPerformOperation. 444 */ 445 bool remove_vertex_at(const SafePtr<DGVertex>& v1, const SafePtr<DGVertex>& v2); 446 447 // Which vertex is the first to compute 448 SafePtr<DGVertex> first_to_compute_; 449 // prepare_to_traverse must be called before actual traversal 450 void prepare_to_traverse(); 451 // traverse_from(arc) build recurively the traversal order 452 void traverse_from(const SafePtr<DGArc>&); 453 // schedule_computation(vertex) puts vertex first in the computation order 454 void schedule_computation(const SafePtr<DGVertex>&); 455 456 // Compute addresses on stack assuming that quantities larger than min_size_to_alloc to be allocated on stack 457 void allocate_mem(const SafePtr<MemoryManager>& memman, 458 const SafePtr<ImplicitDimensions>& dims, 459 unsigned int min_size_to_alloc = 1); 460 // Assign symbols to the vertices 461 void assign_symbols(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims); 462 // If v is an AlgebraicOperator, assign (recursively) symbol to the operator. All other must have been already assigned 463 void assign_oper_symbol(const SafePtr<CodeContext>& context, SafePtr<DGVertex>& v); 464 // Print the code using symbols generated with assign_symbols() 465 void print_def(const SafePtr<CodeContext>& context, std::ostream& os, 466 const SafePtr<ImplicitDimensions>& dims, 467 const SafePtr<CodeSymbols>& args); 468 469 /** Returns true if cannot enclose the code in a vector loop 470 Possible reason: the traversal path contains a RecurrenceRelation that generates a function call 471 (most do except IntegralSet_to_Integrals; \sa RecurrentRelation::is_simple() ) 472 */ 473 bool cannot_enclose_in_outer_vloop() const; 474 475 }; 476 477 // 478 // Nonmember utilities 479 // 480 481 /// converts what is stored in the container to a smart ptr to the vertex vertex_ptr(const DirectedGraph::VPtrAssociativeContainer::value_type & v)482 inline const DirectedGraph::ver_ptr& vertex_ptr(const DirectedGraph::VPtrAssociativeContainer::value_type& v) { 483 return v.second; 484 } vertex_ptr(const DirectedGraph::VPtrSequenceContainer::value_type & v)485 inline const DirectedGraph::ver_ptr& vertex_ptr(const DirectedGraph::VPtrSequenceContainer::value_type& v) { 486 return v; 487 } 488 /// converts what is stored in the container to a smart ptr to the vertex vertex_ptr(DirectedGraph::VPtrAssociativeContainer::value_type & v)489 inline DirectedGraph::ver_ptr& vertex_ptr(DirectedGraph::VPtrAssociativeContainer::value_type& v) { 490 return v.second; 491 } vertex_ptr(DirectedGraph::VPtrSequenceContainer::value_type & v)492 inline DirectedGraph::ver_ptr& vertex_ptr(DirectedGraph::VPtrSequenceContainer::value_type& v) { 493 return v; 494 } 495 496 #if USE_ASSOCCONTAINER_BASED_DIRECTEDGRAPH 497 inline DirectedGraph::key_type key(const DGVertex& v); 498 #endif 499 500 // 501 // Nonmember predicates 502 // 503 504 /// return true if there are non-unrolled targets 505 bool nonunrolled_targets(const DirectedGraph::targets& targets); 506 507 /// extracts external symbols and RRs from the graph 508 void extract_symbols(const SafePtr<DirectedGraph>& dg); 509 510 // use these functors with DirectedGraph::foreach 511 struct PrerequisitesExtractor { 512 std::deque< SafePtr<DGVertex> > vertices; 513 void operator()(const SafePtr<DGVertex>& v); 514 }; 515 struct VertexPrinter { VertexPrinterVertexPrinter516 VertexPrinter(std::ostream& ostr) : os(ostr) {} 517 std::ostream& os; 518 void operator()(const SafePtr<DGVertex>& v); 519 }; 520 521 }; 522 523 524 #endif 525