1 /*
2 * Copyright (C) 2004-2021 Edward F. Valeev
3 *
4 * This file is part of Libint.
5 *
6 * Libint is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * Libint is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with Libint. If not, see <http://www.gnu.org/licenses/>.
18 *
19 */
20
21 #include <fstream>
22 #include <limits>
23
24 #include <rr.h>
25 #include <dg.h>
26 #include <strategy.h>
27 #include <code.h>
28 #include <graph_registry.h>
29 #include <extract.h>
30 #include <algebra.h>
31 #include <context.h>
32 #include <integral.h>
33 #include <task.h>
34 #include <prefactors.h>
35 #include <singl_stack.h>
36
37 using namespace std;
38 using namespace libint2;
39 using namespace libint2::prefactor;
40
RecurrenceRelation()41 RecurrenceRelation::RecurrenceRelation() :
42 nflops_(0), expr_()
43 {
44 }
45
~RecurrenceRelation()46 RecurrenceRelation::~RecurrenceRelation()
47 {
48 }
49
50 //
51 // If there is no generic equivalent, generate explicit code for this recurrence relation:
52 // 1) append target and children to a DirectedGraph dg
53 // 2) set their code symbols
54 // 3) apply IntSet_to_Ints
55 // 4) Apply RRs such that no additional vertices appear
56 // 5) call dg->generate_code()
57 //
58 void
generate_code(const SafePtr<CodeContext> & context,const SafePtr<ImplicitDimensions> & dims,const std::string & funcname,std::ostream & decl,std::ostream & def)59 RecurrenceRelation::generate_code(const SafePtr<CodeContext>& context,
60 const SafePtr<ImplicitDimensions>& dims,
61 const std::string& funcname,
62 std::ostream& decl, std::ostream& def)
63 {
64 //
65 // Check if there is a generic equivalent that can be used
66 //
67 if (this->has_generic(context->cparams())) {
68 generate_generic_code(context,dims,funcname,decl,def);
69 return;
70 }
71
72 const SafePtr<DGVertex> target_vptr = rr_target();
73 #if DEBUG
74 std::cout << "RecurrenceRelation::generate_code: target = " << target_vptr->label() << std::endl;
75 #endif
76
77 const SafePtr<CompilationParameters>& cparams = context->cparams();
78 SafePtr<DirectedGraph> dg(new DirectedGraph);
79 dg->set_label(context->label_to_name(label_to_funcname(funcname)));
80 generate_graph_(dg);
81
82 // Intermediates in RR code are either are automatic variables or have to go on vstack
83 dg->registry()->stack_name("inteval->vstack");
84 // No need to return the targets via inteval's targets
85 dg->registry()->return_targets(false);
86
87 // check if CSE to be performed
88 typedef IntegralSet<IncableBFSet> ISet;
89 SafePtr<ISet> target = dynamic_pointer_cast<ISet,DGVertex>(target_vptr);
90 if (target) {
91 //
92 // do CSE only if max_am <= cparams->max_am_opt()
93 //
94 const unsigned int np = target->num_part();
95 unsigned int max_am = 0;
96 // bra
97 for(unsigned int p=0; p<np; p++) {
98 const unsigned int nf = target->num_func_bra(p);
99 for(unsigned int f=0; f<nf; f++) {
100 // Assuming shells here
101 const unsigned int am = target->bra(p,f).norm();
102 using std::max;
103 max_am = max(max_am,am);
104 }
105 }
106 // ket
107 for(unsigned int p=0; p<np; p++) {
108 const unsigned int nf = target->num_func_ket(p);
109 for(unsigned int f=0; f<nf; f++) {
110 // Assuming shells here
111 const unsigned int am = target->ket(p,f).norm();
112 using std::max;
113 max_am = max(max_am,am);
114 }
115 }
116 const bool need_to_optimize = (max_am <= cparams->max_am_opt());
117 dg->registry()->do_cse(need_to_optimize);
118 }
119 dg->registry()->condense_expr(condense_expr(std::numeric_limits<unsigned int>::max(),cparams->max_vector_length()>1));
120 dg->registry()->ignore_missing_prereqs(true); // assume all prerequisites are available -- if some are not, something is VERY broken
121
122 #if PRINT_DAG_GRAPHVIZ
123 {
124 std::basic_ofstream<char> dotfile(dg->label() + ".strat.dot");
125 dg->print_to_dot(false,dotfile);
126 }
127 #endif
128
129 // Assign symbols for the target and source integral sets
130 SafePtr<CodeSymbols> symbols(new CodeSymbols);
131 assign_symbols_(symbols);
132 // Traverse the graph
133 dg->optimize_rr_out(context);
134 dg->traverse();
135 #if PRINT_DAG_GRAPHVIZ
136 {
137 std::basic_ofstream<char> dotfile(dg->label() + ".expr.dot");
138 dg->print_to_dot(false,dotfile);
139 }
140 #endif
141 // Generate code
142 SafePtr<MemoryManager> memman(new WorstFitMemoryManager());
143 SafePtr<ImplicitDimensions> localdims = adapt_dims_(dims);
144 dg->generate_code(context,memman,localdims,symbols,funcname,decl,def);
145
146 // extract all external symbols -- these will be members of the evaluator structure
147 SafePtr<ExtractExternSymbols> extractor(new ExtractExternSymbols);
148 dg->foreach(*extractor);
149 const ExtractExternSymbols::Symbols& externsymbols = extractor->symbols();
150
151 #if 0
152 // print out the symbols
153 std::cout << "Recovered symbols from DirectedGraph for " << label() << std::endl;
154 typedef ExtractExternSymbols::Symbols::const_iterator citer;
155 citer end = externsymbols.end();
156 for(citer t=externsymbols.begin(); t!=end; ++t)
157 std::cout << *t << std::endl;
158 #endif
159
160 #if PRINT_DAG_GRAPHVIZ
161 {
162 std::basic_ofstream<char> dotfile(dg->label() + ".symb.dot");
163 dg->print_to_dot(false,dotfile);
164 }
165 #endif
166
167 // get this RR InstanceID
168 RRStack::InstanceID myid = RRStack::Instance()->find(EnableSafePtrFromThis<this_type>::SafePtr_from_this()).first;
169
170 // For each task which requires this RR:
171 // 1) update max stack size
172 // 2) append external symbols from this RR to its list
173 LibraryTaskManager& taskmgr = LibraryTaskManager::Instance();
174 typedef LibraryTaskManager::TasksCIter tciter;
175 const tciter tend = taskmgr.plast();
176 for(tciter t=taskmgr.first(); t!=tend; ++t) {
177 const SafePtr<TaskExternSymbols> tsymbols = t->symbols();
178 if (tsymbols->find(myid)) {
179 // update max stack size
180 t->params()->max_vector_stack_size(memman->max_memory_used());
181 // add external symbols
182 tsymbols->add(externsymbols);
183 }
184 }
185
186 dg->reset();
187 }
188
189 namespace libint2 {
190 // generate_generic_code reuses this function from dg.cc:
191 extern std::string declare_function(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims,
192 const SafePtr<CodeSymbols>& args, const std::string& tlabel, const std::string& function_descr,
193 std::ostream& decl);
194 }
195
196 void
generate_generic_code(const SafePtr<CodeContext> & context,const SafePtr<ImplicitDimensions> & dims,const std::string & funcname,std::ostream & decl,std::ostream & def)197 RecurrenceRelation::generate_generic_code(const SafePtr<CodeContext>& context,
198 const SafePtr<ImplicitDimensions>& dims,
199 const std::string& funcname,
200 std::ostream& decl, std::ostream& def)
201 {
202 const SafePtr<DGVertex> target_vptr = rr_target();
203 std::cout << "RecurrenceRelation::generate_generic_code: target = " << target_vptr->label() << std::endl;
204
205 LibraryTaskManager& taskmgr = LibraryTaskManager::Instance();
206 const std::string tlabel = taskmgr.current().label();
207 SafePtr<ImplicitDimensions> localdims = adapt_dims_(dims);
208 // Assign symbols for the target and source integral sets
209 SafePtr<CodeSymbols> symbols(new CodeSymbols);
210 assign_symbols_(symbols);
211
212 // declare function
213 const std::string func_decl = declare_function(context,localdims,symbols,tlabel,funcname,decl);
214
215 //
216 // Generate function's definition
217 //
218
219 // include standard headers
220 def << context->std_header();
221 // + generic code declaration
222 def << "#include <"
223 << this->generic_header()
224 << ">" << endl;
225 def << endl;
226
227 // start the body ...
228 def << context->code_prefix();
229 def << func_decl << context->open_block() << endl;
230 def << context->std_function_header();
231
232 // ... fill the body
233 def << this->generic_instance(context,symbols) << endl;
234
235 // ... end the body
236 def << context->close_block() << endl;
237 def << context->code_postfix();
238 }
239
240 SafePtr<DirectedGraph>
generate_graph_(const SafePtr<DirectedGraph> & dg)241 RecurrenceRelation::generate_graph_(const SafePtr<DirectedGraph>& dg)
242 {
243 dg->append_target(rr_target());
244 for(unsigned int c=0; c<num_children(); c++)
245 dg->append_vertex(rr_child(c));
246 #if DEBUG
247 cout << "RecurrenceRelation::generate_code -- the number of integral sets = " << dg->num_vertices() << endl;
248 #endif
249 SafePtr<Strategy> strat(new Strategy);
250 SafePtr<Tactic> ntactic(new NullTactic);
251 // Always need to unroll integral sets first
252 dg->registry()->unroll_threshold(std::numeric_limits<unsigned int>::max());
253 dg->apply(strat,ntactic);
254 #if DEBUG
255 cout << "RecurrenceRelation::generate_code -- the number of integral sets + integrals = " << dg->num_vertices() << endl;
256 #endif
257 // Mark children sets and their descendants to not compute
258 for(unsigned int c=0; c<num_children(); c++)
259 dg->apply_at<&DGVertex::not_need_to_compute>(rr_child(c));
260 // Apply recurrence relations using existing vertices on the graph (i.e.
261 // such that no new vertices appear)
262 SafePtr<Tactic> ztactic(new FewestNewVerticesTactic(dg));
263 dg->apply(strat,ztactic);
264 #if DEBUG
265 cout << "RecurrenceRelation::generate_code -- should be same as previous = " << dg->num_vertices() << endl;
266 #endif
267
268 return dg;
269 }
270
271 void
assign_symbols_(SafePtr<CodeSymbols> & symbols)272 RecurrenceRelation::assign_symbols_(SafePtr<CodeSymbols>& symbols)
273 {
274 // Set symbols on the target and children sets
275 rr_target()->set_symbol("target");
276 symbols->append_symbol("target");
277 for(unsigned int c=0; c<num_children(); c++) {
278 ostringstream oss;
279 oss << "src" << c;
280 string symb = oss.str();
281 rr_child(c)->set_symbol(symb);
282 symbols->append_symbol(symb);
283 }
284 }
285
286 SafePtr<ImplicitDimensions>
adapt_dims_(const SafePtr<ImplicitDimensions> & dims) const287 RecurrenceRelation::adapt_dims_(const SafePtr<ImplicitDimensions>& dims) const
288 {
289 return dims;
290 }
291
292 const std::string&
label() const293 RecurrenceRelation::label() const {
294 if (label_.empty())
295 label_ = generate_label();
296 return label_;
297 }
298
299 std::string
description() const300 RecurrenceRelation::description() const
301 {
302 const std::string descr = label();
303 return descr;
304 }
305
306 void
add_expr(const SafePtr<ExprType> & expr,int minus)307 RecurrenceRelation::add_expr(const SafePtr<ExprType>& expr, int minus)
308 {
309 if (expr_ == 0) {
310 if (minus != -1) {
311 expr_ = expr;
312 }
313 else {
314 using libint2::prefactor::Scalar;
315 SafePtr<ExprType> negative(new ExprType(ExprType::OperatorTypes::Times,expr,Scalar(-1.0)));
316 expr_ = negative;
317 ++nflops_;
318 }
319 }
320 else {
321 if (minus != -1) {
322 SafePtr<ExprType> sum(new ExprType(ExprType::OperatorTypes::Plus,expr_,expr));
323 expr_ = sum;
324 ++nflops_;
325 }
326 else {
327 SafePtr<ExprType> sum(new ExprType(ExprType::OperatorTypes::Minus,expr_,expr));
328 expr_ = sum;
329 ++nflops_;
330 }
331 }
332 }
333
334
335 bool
invariant_type() const336 RecurrenceRelation::invariant_type() const {
337 // By default, recurrence relations do not change the type of the functions, i.e. VRR applied to an integral over shells will produce integrals over shells
338 return true;
339 }
340
341 std::string
spfunction_call(const SafePtr<CodeContext> & context,const SafePtr<ImplicitDimensions> & dims) const342 RecurrenceRelation::spfunction_call(const SafePtr<CodeContext>& context, const SafePtr<ImplicitDimensions>& dims) const
343 {
344 ostringstream os;
345 os << context->label_to_name(label_to_funcname(context->cparams()->api_prefix() + label()))
346 // First argument is the library object
347 << "(inteval, "
348 // Second is the target
349 << context->value_to_pointer(rr_target()->symbol());
350 // then come children
351 const unsigned int nchildren = num_children();
352 for(unsigned int c=0; c<nchildren; c++) {
353 os << ", " << context->value_to_pointer(rr_child(c)->symbol());
354 }
355 os << ")" << context->end_of_stat() << endl;
356 return os.str();
357 }
358
359 bool
has_generic(const SafePtr<CompilationParameters> & cparams) const360 RecurrenceRelation::has_generic(const SafePtr<CompilationParameters>& cparams) const {
361 return false;
362 }
363
364 std::string
generic_header() const365 RecurrenceRelation::generic_header() const {
366 throw std::logic_error("RecurrenceRelation::generic_header() -- should not be called! Check if DerivedRecurrenceRelation::generic_header() is implemented");
367 }
368
369 std::string
generic_instance(const SafePtr<CodeContext> & context,const SafePtr<CodeSymbols> & args) const370 RecurrenceRelation::generic_instance(const SafePtr<CodeContext>& context, const SafePtr<CodeSymbols>& args) const {
371 throw std::logic_error("RecurrenceRelation::generic_instance() -- should not be called! Check if DerivedRecurrenceRelation::generic_instance() is implemented");
372 }
373
374 size_t
size_of_children() const375 RecurrenceRelation::size_of_children() const {
376 const auto nchildren = this->num_children();
377 size_t result = 0;
378 for(auto c=0; c!=nchildren; ++c) {
379 result += this->rr_child(c)->size();
380 }
381 return result;
382 }
383
384 namespace libint2 { namespace algebra {
385 /// these operators are extremely useful to write compact expressions
operator +(const SafePtr<DGVertex> & A,const SafePtr<DGVertex> & B)386 SafePtr<RecurrenceRelation::ExprType> operator+(const SafePtr<DGVertex>& A,
387 const SafePtr<DGVertex>& B) {
388 typedef RecurrenceRelation::ExprType Oper;
389 return SafePtr<Oper>(new Oper(Oper::OperatorTypes::Plus,A,B));
390 }
operator -(const SafePtr<DGVertex> & A,const SafePtr<DGVertex> & B)391 SafePtr<RecurrenceRelation::ExprType> operator-(const SafePtr<DGVertex>& A,
392 const SafePtr<DGVertex>& B) {
393 typedef RecurrenceRelation::ExprType Oper;
394 return SafePtr<Oper>(new Oper(Oper::OperatorTypes::Minus,A,B));
395 }
operator *(const SafePtr<DGVertex> & A,const SafePtr<DGVertex> & B)396 SafePtr<RecurrenceRelation::ExprType> operator*(const SafePtr<DGVertex>& A,
397 const SafePtr<DGVertex>& B) {
398 typedef RecurrenceRelation::ExprType Oper;
399 return SafePtr<Oper>(new Oper(Oper::OperatorTypes::Times,A,B));
400 }
operator /(const SafePtr<DGVertex> & A,const SafePtr<DGVertex> & B)401 SafePtr<RecurrenceRelation::ExprType> operator/(const SafePtr<DGVertex>& A,
402 const SafePtr<DGVertex>& B) {
403 typedef RecurrenceRelation::ExprType Oper;
404 return SafePtr<Oper>(new Oper(Oper::OperatorTypes::Divide,A,B));
405 }
operator +=(SafePtr<RecurrenceRelation::ExprType> & A,const SafePtr<DGVertex> & B)406 const SafePtr<RecurrenceRelation::ExprType>& operator+=(SafePtr<RecurrenceRelation::ExprType>& A,
407 const SafePtr<DGVertex>& B) {
408 typedef RecurrenceRelation::ExprType Oper;
409 if (A) {
410 const SafePtr<Oper>& Sum = A + B;
411 A = Sum;
412 }
413 else
414 A = Scalar(0) + B;
415 return A;
416 }
operator -=(SafePtr<RecurrenceRelation::ExprType> & A,const SafePtr<DGVertex> & B)417 const SafePtr<RecurrenceRelation::ExprType>& operator-=(SafePtr<RecurrenceRelation::ExprType>& A,
418 const SafePtr<DGVertex>& B) {
419 typedef RecurrenceRelation::ExprType Oper;
420 if (A) {
421 const SafePtr<Oper>& Diff = A - B;
422 A = Diff;
423 }
424 else
425 A = Scalar(0) - B;
426 return A;
427 }
operator *=(SafePtr<RecurrenceRelation::ExprType> & A,const SafePtr<DGVertex> & B)428 const SafePtr<RecurrenceRelation::ExprType>& operator*=(SafePtr<RecurrenceRelation::ExprType>& A,
429 const SafePtr<DGVertex>& B) {
430 typedef RecurrenceRelation::ExprType Oper;
431 const SafePtr<Oper>& Product = A * B;
432 A = Product;
433 return A;
434 }
operator /=(SafePtr<RecurrenceRelation::ExprType> & A,const SafePtr<DGVertex> & B)435 const SafePtr<RecurrenceRelation::ExprType>& operator/=(SafePtr<RecurrenceRelation::ExprType>& A,
436 const SafePtr<DGVertex>& B) {
437 typedef RecurrenceRelation::ExprType Oper;
438 const SafePtr<Oper>& Quotient = A / B;
439 A = Quotient;
440 return A;
441 }
442 } } // namespace libint2::algebra
443
444 ///////////////
445