1 /*
2  *  Copyright (C) 2004-2021 Edward F. Valeev
3  *
4  *  This file is part of Libint.
5  *
6  *  Libint is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  Libint is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU General Public License for more details.
15  *
16  *  You should have received a copy of the GNU General Public License
17  *  along with Libint.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 #include <fstream>
22 #include <limits>
23 
24 #include <rr.h>
25 #include <dg.h>
26 #include <strategy.h>
27 #include <code.h>
28 #include <graph_registry.h>
29 #include <extract.h>
30 #include <algebra.h>
31 #include <context.h>
32 #include <integral.h>
33 #include <task.h>
34 #include <prefactors.h>
35 #include <singl_stack.h>
36 
37 using namespace std;
38 using namespace libint2;
39 using namespace libint2::prefactor;
40 
RecurrenceRelation()41 RecurrenceRelation::RecurrenceRelation() :
42   nflops_(0), expr_()
43 {
44 }
45 
~RecurrenceRelation()46 RecurrenceRelation::~RecurrenceRelation()
47 {
48 }
49 
50 //
51 // If there is no generic equivalent, generate explicit code for this recurrence relation:
52 // 1) append target and children to a DirectedGraph dg
53 // 2) set their code symbols
54 // 3) apply IntSet_to_Ints
55 // 4) Apply RRs such that no additional vertices appear
56 // 5) call dg->generate_code()
57 //
58 void
generate_code(const SafePtr<CodeContext> & context,const SafePtr<ImplicitDimensions> & dims,const std::string & funcname,std::ostream & decl,std::ostream & def)59 RecurrenceRelation::generate_code(const SafePtr<CodeContext>& context,
60                                   const SafePtr<ImplicitDimensions>& dims,
61                                   const std::string& funcname,
62                                   std::ostream& decl, std::ostream& def)
63 {
64   //
65   // Check if there is a generic equivalent that can be used
66   //
67   if (this->has_generic(context->cparams())) {
68     generate_generic_code(context,dims,funcname,decl,def);
69     return;
70   }
71 
72   const SafePtr<DGVertex> target_vptr = rr_target();
73 #if DEBUG
74   std::cout << "RecurrenceRelation::generate_code: target = " << target_vptr->label() << std::endl;
75 #endif
76 
77   const SafePtr<CompilationParameters>& cparams = context->cparams();
78   SafePtr<DirectedGraph> dg(new DirectedGraph);
79   dg->set_label(context->label_to_name(label_to_funcname(funcname)));
80   generate_graph_(dg);
81 
82   // Intermediates in RR code are either are automatic variables or have to go on vstack
83   dg->registry()->stack_name("inteval->vstack");
84   // No need to return the targets via inteval's targets
85   dg->registry()->return_targets(false);
86 
87   // check if CSE to be performed
88   typedef IntegralSet<IncableBFSet> ISet;
89   SafePtr<ISet> target = dynamic_pointer_cast<ISet,DGVertex>(target_vptr);
90   if (target) {
91     //
92     // do CSE only if max_am <= cparams->max_am_opt()
93     //
94     const unsigned int np = target->num_part();
95     unsigned int max_am = 0;
96     // bra
97     for(unsigned int p=0; p<np; p++) {
98       const unsigned int nf = target->num_func_bra(p);
99       for(unsigned int f=0; f<nf; f++) {
100 	// Assuming shells here
101 	const unsigned int am = target->bra(p,f).norm();
102 	using std::max;
103 	max_am = max(max_am,am);
104       }
105     }
106     // ket
107     for(unsigned int p=0; p<np; p++) {
108       const unsigned int nf = target->num_func_ket(p);
109       for(unsigned int f=0; f<nf; f++) {
110 	// Assuming shells here
111 	const unsigned int am = target->ket(p,f).norm();
112 	using std::max;
113 	max_am = max(max_am,am);
114       }
115     }
116     const bool need_to_optimize = (max_am <= cparams->max_am_opt());
117     dg->registry()->do_cse(need_to_optimize);
118   }
119   dg->registry()->condense_expr(condense_expr(std::numeric_limits<unsigned int>::max(),cparams->max_vector_length()>1));
120   dg->registry()->ignore_missing_prereqs(true);  // assume all prerequisites are available -- if some are not, something is VERY broken
121 
122 #if PRINT_DAG_GRAPHVIZ
123   {
124     std::basic_ofstream<char> dotfile(dg->label() + ".strat.dot");
125     dg->print_to_dot(false,dotfile);
126   }
127 #endif
128 
129   // Assign symbols for the target and source integral sets
130   SafePtr<CodeSymbols> symbols(new CodeSymbols);
131   assign_symbols_(symbols);
132   // Traverse the graph
133   dg->optimize_rr_out(context);
134   dg->traverse();
135 #if PRINT_DAG_GRAPHVIZ
136     {
137       std::basic_ofstream<char> dotfile(dg->label() + ".expr.dot");
138       dg->print_to_dot(false,dotfile);
139     }
140 #endif
141   // Generate code
142   SafePtr<MemoryManager> memman(new WorstFitMemoryManager());
143   SafePtr<ImplicitDimensions> localdims = adapt_dims_(dims);
144   dg->generate_code(context,memman,localdims,symbols,funcname,decl,def);
145 
146   // extract all external symbols -- these will be members of the evaluator structure
147   SafePtr<ExtractExternSymbols> extractor(new ExtractExternSymbols);
148   dg->foreach(*extractor);
149   const ExtractExternSymbols::Symbols& externsymbols = extractor->symbols();
150 
151 #if 0
152   // print out the symbols
153   std::cout << "Recovered symbols from DirectedGraph for " << label() << std::endl;
154   typedef ExtractExternSymbols::Symbols::const_iterator citer;
155   citer end = externsymbols.end();
156   for(citer t=externsymbols.begin(); t!=end; ++t)
157     std::cout << *t << std::endl;
158 #endif
159 
160 #if PRINT_DAG_GRAPHVIZ
161     {
162       std::basic_ofstream<char> dotfile(dg->label() + ".symb.dot");
163       dg->print_to_dot(false,dotfile);
164     }
165 #endif
166 
167   // get this RR InstanceID
168   RRStack::InstanceID myid = RRStack::Instance()->find(EnableSafePtrFromThis<this_type>::SafePtr_from_this()).first;
169 
170   // For each task which requires this RR:
171   // 1) update max stack size
172   // 2) append external symbols from this RR to its list
173   LibraryTaskManager& taskmgr = LibraryTaskManager::Instance();
174   typedef LibraryTaskManager::TasksCIter tciter;
175   const tciter tend = taskmgr.plast();
176   for(tciter t=taskmgr.first(); t!=tend; ++t) {
177     const SafePtr<TaskExternSymbols> tsymbols = t->symbols();
178     if (tsymbols->find(myid)) {
179       // update max stack size
180       t->params()->max_vector_stack_size(memman->max_memory_used());
181       // add external symbols
182       tsymbols->add(externsymbols);
183     }
184   }
185 
186   dg->reset();
187 }
188 
189 namespace libint2 {
190   // generate_generic_code reuses this function from dg.cc:
191   extern std::string declare_function(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims,
192                                       const SafePtr<CodeSymbols>& args, const std::string& tlabel, const std::string& function_descr,
193                                       std::ostream& decl);
194 }
195 
196 void
generate_generic_code(const SafePtr<CodeContext> & context,const SafePtr<ImplicitDimensions> & dims,const std::string & funcname,std::ostream & decl,std::ostream & def)197 RecurrenceRelation::generate_generic_code(const SafePtr<CodeContext>& context,
198                                           const SafePtr<ImplicitDimensions>& dims,
199                                           const std::string& funcname,
200                                           std::ostream& decl, std::ostream& def)
201 {
202   const SafePtr<DGVertex> target_vptr = rr_target();
203   std::cout << "RecurrenceRelation::generate_generic_code: target = " << target_vptr->label() << std::endl;
204 
205   LibraryTaskManager& taskmgr = LibraryTaskManager::Instance();
206   const std::string tlabel = taskmgr.current().label();
207   SafePtr<ImplicitDimensions> localdims = adapt_dims_(dims);
208   // Assign symbols for the target and source integral sets
209   SafePtr<CodeSymbols> symbols(new CodeSymbols);
210   assign_symbols_(symbols);
211 
212   // declare function
213   const std::string func_decl = declare_function(context,localdims,symbols,tlabel,funcname,decl);
214 
215   //
216   // Generate function's definition
217   //
218 
219   // include standard headers
220   def << context->std_header();
221   //         + generic code declaration
222   def << "#include <"
223       << this->generic_header()
224       << ">" << endl;
225   def << endl;
226 
227   // start the body ...
228   def << context->code_prefix();
229   def << func_decl << context->open_block() << endl;
230   def << context->std_function_header();
231 
232   // ... fill the body
233   def << this->generic_instance(context,symbols) << endl;
234 
235   // ... end the body
236   def << context->close_block() << endl;
237   def << context->code_postfix();
238 }
239 
240 SafePtr<DirectedGraph>
generate_graph_(const SafePtr<DirectedGraph> & dg)241 RecurrenceRelation::generate_graph_(const SafePtr<DirectedGraph>& dg)
242 {
243   dg->append_target(rr_target());
244   for(unsigned int c=0; c<num_children(); c++)
245     dg->append_vertex(rr_child(c));
246 #if DEBUG
247   cout << "RecurrenceRelation::generate_code -- the number of integral sets = " << dg->num_vertices() << endl;
248 #endif
249   SafePtr<Strategy> strat(new Strategy);
250   SafePtr<Tactic> ntactic(new NullTactic);
251   // Always need to unroll integral sets first
252   dg->registry()->unroll_threshold(std::numeric_limits<unsigned int>::max());
253   dg->apply(strat,ntactic);
254 #if DEBUG
255   cout << "RecurrenceRelation::generate_code -- the number of integral sets + integrals = " << dg->num_vertices() << endl;
256 #endif
257   // Mark children sets and their descendants to not compute
258   for(unsigned int c=0; c<num_children(); c++)
259     dg->apply_at<&DGVertex::not_need_to_compute>(rr_child(c));
260   // Apply recurrence relations using existing vertices on the graph (i.e.
261   // such that no new vertices appear)
262   SafePtr<Tactic> ztactic(new FewestNewVerticesTactic(dg));
263   dg->apply(strat,ztactic);
264 #if DEBUG
265   cout << "RecurrenceRelation::generate_code -- should be same as previous = " << dg->num_vertices() << endl;
266 #endif
267 
268   return dg;
269 }
270 
271 void
assign_symbols_(SafePtr<CodeSymbols> & symbols)272 RecurrenceRelation::assign_symbols_(SafePtr<CodeSymbols>& symbols)
273 {
274   // Set symbols on the target and children sets
275   rr_target()->set_symbol("target");
276   symbols->append_symbol("target");
277   for(unsigned int c=0; c<num_children(); c++) {
278     ostringstream oss;
279     oss << "src" << c;
280     string symb = oss.str();
281     rr_child(c)->set_symbol(symb);
282     symbols->append_symbol(symb);
283   }
284 }
285 
286 SafePtr<ImplicitDimensions>
adapt_dims_(const SafePtr<ImplicitDimensions> & dims) const287 RecurrenceRelation::adapt_dims_(const SafePtr<ImplicitDimensions>& dims) const
288 {
289   return dims;
290 }
291 
292 const std::string&
label() const293 RecurrenceRelation::label() const {
294   if (label_.empty())
295     label_ = generate_label();
296   return label_;
297 }
298 
299 std::string
description() const300 RecurrenceRelation::description() const
301 {
302   const std::string descr = label();
303   return descr;
304 }
305 
306 void
add_expr(const SafePtr<ExprType> & expr,int minus)307 RecurrenceRelation::add_expr(const SafePtr<ExprType>& expr, int minus)
308 {
309   if (expr_ == 0) {
310     if (minus != -1) {
311       expr_ = expr;
312     }
313     else {
314       using libint2::prefactor::Scalar;
315       SafePtr<ExprType> negative(new ExprType(ExprType::OperatorTypes::Times,expr,Scalar(-1.0)));
316       expr_ = negative;
317       ++nflops_;
318     }
319   }
320   else {
321     if (minus != -1) {
322       SafePtr<ExprType> sum(new ExprType(ExprType::OperatorTypes::Plus,expr_,expr));
323       expr_ = sum;
324       ++nflops_;
325     }
326     else {
327       SafePtr<ExprType> sum(new ExprType(ExprType::OperatorTypes::Minus,expr_,expr));
328       expr_ = sum;
329       ++nflops_;
330     }
331   }
332 }
333 
334 
335 bool
invariant_type() const336 RecurrenceRelation::invariant_type() const {
337   // By default, recurrence relations do not change the type of the functions, i.e. VRR applied to an integral over shells will produce integrals over shells
338   return true;
339 }
340 
341 std::string
spfunction_call(const SafePtr<CodeContext> & context,const SafePtr<ImplicitDimensions> & dims) const342 RecurrenceRelation::spfunction_call(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims) const
343 {
344   ostringstream os;
345   os << context->label_to_name(label_to_funcname(context->cparams()->api_prefix() + label()))
346     // First argument is the library object
347      << "(inteval, "
348     // Second is the target
349      << context->value_to_pointer(rr_target()->symbol());
350   // then come children
351   const unsigned int nchildren = num_children();
352   for(unsigned int c=0; c<nchildren; c++) {
353     os << ", " << context->value_to_pointer(rr_child(c)->symbol());
354   }
355   os << ")" << context->end_of_stat() << endl;
356   return os.str();
357 }
358 
359 bool
has_generic(const SafePtr<CompilationParameters> & cparams) const360 RecurrenceRelation::has_generic(const SafePtr<CompilationParameters>& cparams) const {
361   return false;
362 }
363 
364 std::string
generic_header() const365 RecurrenceRelation::generic_header() const {
366   throw std::logic_error("RecurrenceRelation::generic_header() -- should not be called! Check if DerivedRecurrenceRelation::generic_header() is implemented");
367 }
368 
369 std::string
generic_instance(const SafePtr<CodeContext> & context,const SafePtr<CodeSymbols> & args) const370 RecurrenceRelation::generic_instance(const SafePtr<CodeContext>& context, const SafePtr<CodeSymbols>& args) const {
371   throw std::logic_error("RecurrenceRelation::generic_instance() -- should not be called! Check if DerivedRecurrenceRelation::generic_instance() is implemented");
372 }
373 
374 size_t
size_of_children() const375 RecurrenceRelation::size_of_children() const {
376   const auto nchildren = this->num_children();
377   size_t result = 0;
378   for(auto c=0; c!=nchildren; ++c) {
379     result += this->rr_child(c)->size();
380   }
381   return result;
382 }
383 
384 namespace libint2 { namespace algebra {
385   /// these operators are extremely useful to write compact expressions
operator +(const SafePtr<DGVertex> & A,const SafePtr<DGVertex> & B)386   SafePtr<RecurrenceRelation::ExprType> operator+(const SafePtr<DGVertex>& A,
387                                                   const SafePtr<DGVertex>& B) {
388     typedef RecurrenceRelation::ExprType Oper;
389     return SafePtr<Oper>(new Oper(Oper::OperatorTypes::Plus,A,B));
390   }
operator -(const SafePtr<DGVertex> & A,const SafePtr<DGVertex> & B)391   SafePtr<RecurrenceRelation::ExprType> operator-(const SafePtr<DGVertex>& A,
392                                                   const SafePtr<DGVertex>& B) {
393     typedef RecurrenceRelation::ExprType Oper;
394     return SafePtr<Oper>(new Oper(Oper::OperatorTypes::Minus,A,B));
395   }
operator *(const SafePtr<DGVertex> & A,const SafePtr<DGVertex> & B)396   SafePtr<RecurrenceRelation::ExprType> operator*(const SafePtr<DGVertex>& A,
397                                                   const SafePtr<DGVertex>& B) {
398     typedef RecurrenceRelation::ExprType Oper;
399     return SafePtr<Oper>(new Oper(Oper::OperatorTypes::Times,A,B));
400   }
operator /(const SafePtr<DGVertex> & A,const SafePtr<DGVertex> & B)401   SafePtr<RecurrenceRelation::ExprType> operator/(const SafePtr<DGVertex>& A,
402                                                   const SafePtr<DGVertex>& B) {
403     typedef RecurrenceRelation::ExprType Oper;
404     return SafePtr<Oper>(new Oper(Oper::OperatorTypes::Divide,A,B));
405   }
operator +=(SafePtr<RecurrenceRelation::ExprType> & A,const SafePtr<DGVertex> & B)406   const SafePtr<RecurrenceRelation::ExprType>& operator+=(SafePtr<RecurrenceRelation::ExprType>& A,
407                                                           const SafePtr<DGVertex>& B) {
408     typedef RecurrenceRelation::ExprType Oper;
409     if (A) {
410       const SafePtr<Oper>& Sum = A + B;
411       A = Sum;
412     }
413     else
414       A = Scalar(0) + B;
415     return A;
416   }
operator -=(SafePtr<RecurrenceRelation::ExprType> & A,const SafePtr<DGVertex> & B)417   const SafePtr<RecurrenceRelation::ExprType>& operator-=(SafePtr<RecurrenceRelation::ExprType>& A,
418                                                           const SafePtr<DGVertex>& B) {
419     typedef RecurrenceRelation::ExprType Oper;
420     if (A) {
421       const SafePtr<Oper>& Diff = A - B;
422       A = Diff;
423     }
424     else
425       A = Scalar(0) - B;
426     return A;
427   }
operator *=(SafePtr<RecurrenceRelation::ExprType> & A,const SafePtr<DGVertex> & B)428   const SafePtr<RecurrenceRelation::ExprType>& operator*=(SafePtr<RecurrenceRelation::ExprType>& A,
429                                                           const SafePtr<DGVertex>& B) {
430     typedef RecurrenceRelation::ExprType Oper;
431     const SafePtr<Oper>& Product = A * B;
432     A = Product;
433     return A;
434   }
operator /=(SafePtr<RecurrenceRelation::ExprType> & A,const SafePtr<DGVertex> & B)435   const SafePtr<RecurrenceRelation::ExprType>& operator/=(SafePtr<RecurrenceRelation::ExprType>& A,
436                                                           const SafePtr<DGVertex>& B) {
437     typedef RecurrenceRelation::ExprType Oper;
438     const SafePtr<Oper>& Quotient = A / B;
439     A = Quotient;
440     return A;
441   }
442 } } // namespace libint2::algebra
443 
444 ///////////////
445