1 /*
2 * Copyright (C) 2004-2021 Edward F. Valeev
3 *
4 * This file is part of Libint.
5 *
6 * Libint is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * Libint is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with Libint. If not, see <http://www.gnu.org/licenses/>.
18 *
19 */
20
#include <cassert>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <context.h>
#include <codeblock.h>
#include <default_params.h>
26
27 using namespace std;
28 using namespace libint2;
29
30 namespace libint2 {
31
32 template <>
unique_name() const33 std::string CodeContext::unique_name<EntityTypes::FP>() const
34 {
35 return unique_fp_name();
36 }
37 template <>
unique_name() const38 std::string CodeContext::unique_name<EntityTypes::Int>() const
39 {
40 return unique_int_name();
41 }
42
43 template <>
type_name() const44 std::string CodeContext::type_name<void>() const
45 {
46 return void_type();
47 }
48 template <>
type_name() const49 std::string CodeContext::type_name<int>() const
50 {
51 return int_type();
52 }
53 template <>
type_name() const54 std::string CodeContext::type_name<size_t>() const
55 {
56 return size_type();
57 }
58 template <>
type_name() const59 std::string CodeContext::type_name<const int>() const
60 {
61 return const_modifier() + int_type();
62 }
63 template <>
type_name() const64 std::string CodeContext::type_name<double>() const
65 {
66 return fp_type();
67 }
68 template <>
type_name() const69 std::string CodeContext::type_name<double*>() const
70 {
71 return ptr_fp_type();
72 }
73 template <>
type_name() const74 std::string CodeContext::type_name<const double*>() const
75 {
76 return const_modifier() + ptr_fp_type();
77 }
78 template <>
type_name() const79 std::string CodeContext::type_name<double* const>() const
80 {
81 return ptr_fp_type() + const_modifier();
82 }
83 };
84
CodeContext(const SafePtr<CompilationParameters> & cparams)85 CodeContext::CodeContext(const SafePtr<CompilationParameters>& cparams) :
86 cparams_(cparams),
87 comments_on_(false)
88 {
89 zero_out_counters();
90 }
91
92 const SafePtr<CompilationParameters>&
cparams() const93 CodeContext::cparams() const
94 {
95 return cparams_;
96 }
97
98 bool
comments_on() const99 CodeContext::comments_on() const { return comments_on_; }
100
101 unsigned int
next_fp_index() const102 CodeContext::next_fp_index() const
103 {
104 return next_index_[EntityTypes::FP::type2int()]++;
105 }
106
107 unsigned int
next_int_index() const108 CodeContext::next_int_index() const
109 {
110 return next_index_[EntityTypes::Int::type2int()]++;
111 }
112
113 void
zero_out_counters() const114 CodeContext::zero_out_counters() const
115 {
116 for(unsigned int i=0; i<EntityTypes::ntypes; i++)
117 next_index_[i] = 0;
118 }
119
120 void
reset()121 CodeContext::reset()
122 {
123 zero_out_counters();
124 }
125
126 std::string
replace_chars(const std::string & S,const std::string & From,const std::string & To)127 CodeContext::replace_chars(const std::string& S, const std::string& From, const std::string& To)
128 {
129 typedef std::string::size_type size_type;
130
131 const unsigned int max_niter = 1000;
132 unsigned int niter = 0;
133 std::string curr_str(S);
134 size_type curr_pos = curr_str.find(From,0);
135 while (curr_pos != std::string::npos) {
136 niter++;
137 curr_str.replace(curr_pos,From.length(),To,0,To.length());
138 curr_pos += To.length() - From.length();
139 curr_pos = curr_str.find(From,curr_pos);
140 if (niter >= max_niter)
141 throw std::runtime_error("CodeContext::replace_chars() -- infinite recursion detected");
142 }
143 return curr_str;
144 }
145
146 //////////////
147
/// Table of characters that are not allowed in C++ identifiers, paired by
/// index with the text substituted for each one. CppCodeContext::label_to_name()
/// applies the pairs in table order.
namespace ForbiddenCppCharacters {

  /// Number of (forbidden, substitute) pairs.
  static const unsigned int nchars = 16;

  /// chars[i] is the i-th forbidden character, stored as a 1-char C string.
  static const char chars[nchars][2] = {
    "{", "}", "(", ")",
    " ", "+", "-", "/",
    "*", "|", "^", "[",
    "]", ",", "<", ">"
  };

  /// subst_chars[i] replaces chars[i]; an empty entry means "just delete it".
  static const char subst_chars[nchars][20] = {
    "",        "",        "__",      "__",
    "",        "_plus_",  "_minus_", "_over_",
    "_times_", "_",       "_up_",    "_sB_",
    "_Sb_",    "_c_",     "_aB_",    "_Ab_"
  };

}
187
CppCodeContext(const SafePtr<CompilationParameters> & cparams,bool vectorize)188 CppCodeContext::CppCodeContext(const SafePtr<CompilationParameters>& cparams, bool vectorize) :
189 CodeContext(cparams), vectorize_(vectorize)
190 {
191 }
192
~CppCodeContext()193 CppCodeContext::~CppCodeContext()
194 {
195 }
196
197 std::string
code_prefix() const198 CppCodeContext::code_prefix() const
199 {
200 if (cparams()->use_C_linking()) {
201 return "#ifdef __cplusplus\nLIBINT_PRAGMA_CLANG(diagnostic push)\nLIBINT_PRAGMA_CLANG(diagnostic ignored \"-Wunused-variable\")\nLIBINT_PRAGMA_GCC(diagnostic push)\nLIBINT_PRAGMA_GCC(diagnostic ignored \"-Wunused-variable\")\nextern \"C\" {\n#endif\n";
202 }
203 return "";
204 }
205
206 std::string
code_postfix() const207 CppCodeContext::code_postfix() const
208 {
209 if (cparams()->use_C_linking()) {
210 return "#ifdef __cplusplus\n};\nLIBINT_PRAGMA_CLANG(diagnostic pop)\nLIBINT_PRAGMA_GCC(diagnostic pop)\n#endif\n";
211 }
212 return "";
213 }
214
215 std::string
copyright() const216 CppCodeContext::copyright() const {
217 std::ostringstream oss;
218 using std::endl;
219 oss << "/*"<< endl
220 << " * Copyright (C) 2004-2021 Edward F. Valeev" << endl
221 << " *" << endl
222 << " * This file is part of Libint." << endl
223 << " *" << endl
224 << " * Libint is free software: you can redistribute it and/or modify" << endl
225 << " * it under the terms of the GNU Lesser General Public License as published by" << endl
226 << " * the Free Software Foundation, either version 3 of the License, or" << endl
227 << " * (at your option) any later version." <<endl
228 << " *" << endl
229 << " * Libint is distributed in the hope that it will be useful," << endl
230 << " * but WITHOUT ANY WARRANTY; without even the implied warranty of" << endl
231 << " * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the" << endl
232 << " * GNU Lesser General Public License for more details." << endl
233 << " *" << endl
234 << " * You should have received a copy of the GNU Lesser General Public License" << endl
235 << " * along with Libint. If not, see <http://www.gnu.org/licenses/>." << endl
236 << " *" << endl
237 << " */" << endl
238 << endl;
239 return oss.str();
240 }
241
242 std::string
std_header() const243 CppCodeContext::std_header() const
244 {
245 std::string result("#include <libint2.h>\n");
246 return result;
247 }
248
249 std::string
std_function_header() const250 CppCodeContext::std_function_header() const
251 {
252 ostringstream oss;
253 if(vectorize_) {
254 oss << "const int veclen = inteval->veclen;\n";
255 }
256 return oss.str();
257 }
258
259 std::string
label_to_name(const std::string & label) const260 CppCodeContext::label_to_name(const std::string& label) const
261 {
262 std::string str = label;
263 for(unsigned int c=0; c<ForbiddenCppCharacters::nchars; c++) {
264 str = replace_chars(str,ForbiddenCppCharacters::chars[c],ForbiddenCppCharacters::subst_chars[c]);
265 }
266 return str;
267 }
268
269 std::string
declare(const std::string & type,const std::string & name) const270 CppCodeContext::declare(const std::string& type,
271 const std::string& name) const
272 {
273 ostringstream oss;
274
275 oss << type << " " << name << end_of_stat() << endl;
276
277 return oss.str();
278 }
279
280 std::string
declare_v(const std::string & type,const std::string & name,const std::string & nelem) const281 CppCodeContext::declare_v(const std::string& type,
282 const std::string& name,
283 const std::string& nelem) const
284 {
285 ostringstream oss;
286
287 oss << type << " " << name << "[" << nelem << "]" << end_of_stat() << endl;
288
289 return oss.str();
290 }
291
292 std::string
decldef(const std::string & type,const std::string & name,const std::string & value)293 CppCodeContext::decldef(const std::string& type,
294 const std::string& name,
295 const std::string& value)
296 {
297 ostringstream oss;
298
299 oss << type << " " << assign(name,value);
300
301 return oss.str();
302 }
303
304 std::string
assign(const std::string & name,const std::string & value)305 CppCodeContext::assign(const std::string& name,
306 const std::string& value)
307 {
308 return assign_(name,value,false);
309 }
310
311 std::string
accumulate(const std::string & name,const std::string & value)312 CppCodeContext::accumulate(const std::string& name,
313 const std::string& value)
314 {
315 return assign_(name,value,true);
316 }
317
318 std::string
assign_(const std::string & name,const std::string & value,bool accum)319 CppCodeContext::assign_(const std::string& name,
320 const std::string& value,
321 bool accum)
322 {
323 ostringstream oss;
324
325 if (vectorize_) {
326 std::string symb0 = unique_fp_name();
327 std::string symb1 = unique_fp_name();
328 std::string ptr0 = symbol_to_pointer(name);
329 std::string ptr1 = symbol_to_pointer(value);
330 bool symb1_is_a_const = (ptr1.length() == 0);
331 oss << "LIBINT2_REALTYPE* " << symb0 << " = "
332 << symbol_to_pointer(name) << end_of_stat() << endl;
333 oss << "__assume_aligned(" << symb0 << ", 16)" << end_of_stat() << endl;
334 if (!symb1_is_a_const) {
335 oss << "LIBINT2_REALTYPE* " << symb1 << " = "
336 << symbol_to_pointer(value) << end_of_stat() << endl;
337 oss << "__assume_aligned(" << symb1 << ", 16)" << end_of_stat() << endl;
338 }
339
340 oss << start_expr();
341 oss << symb0 << "[v]" << (accum ? " += " : " = ")
342 << (symb1_is_a_const ? value : symb1)
343 << (symb1_is_a_const ? " " : "[v] ");
344 }
345 else {
346 oss << start_expr();
347 oss << name << (accum ? " += " : " = ") << value;
348 }
349 oss << end_of_stat() << endl;
350 oss << end_expr();
351
352 return oss.str();
353 }
354
355 std::string
assign_binary_expr(const std::string & name,const std::string & left,const std::string & oper,const std::string & right)356 CppCodeContext::assign_binary_expr(const std::string& name,
357 const std::string& left,
358 const std::string& oper,
359 const std::string& right)
360 {
361 return assign_binary_expr_(name,left,oper,right,false);
362 }
363
364 std::string
accumulate_binary_expr(const std::string & name,const std::string & left,const std::string & oper,const std::string & right)365 CppCodeContext::accumulate_binary_expr(const std::string& name,
366 const std::string& left,
367 const std::string& oper,
368 const std::string& right)
369 {
370 return assign_binary_expr_(name,left,oper,right,true);
371 }
372
373 std::string
assign_binary_expr_(const std::string & name,const std::string & left,const std::string & oper,const std::string & right,bool accum)374 CppCodeContext::assign_binary_expr_(const std::string& name,
375 const std::string& left,
376 const std::string& oper,
377 const std::string& right,
378 bool accum)
379 {
380 ostringstream oss;
381
382 if (vectorize_) {
383 std::string symb0 = unique_fp_name();
384 std::string symb1 = unique_fp_name();
385 std::string symb2 = unique_fp_name();
386 std::string ptr0 = symbol_to_pointer(name);
387 std::string ptr1 = symbol_to_pointer(left);
388 std::string ptr2 = symbol_to_pointer(right);
389 bool symb1_is_a_const = (ptr1.length() == 0);
390 bool symb2_is_a_const = (ptr2.length() == 0);
391 oss << "LIBINT2_REALTYPE* " << symb0 << " = "
392 << symbol_to_pointer(name) << end_of_stat() << endl;
393 oss << "__assume_aligned(" << symb0 << ", 16)" << end_of_stat() << endl;
394 if (!symb1_is_a_const) {
395 oss << "LIBINT2_REALTYPE* " << symb1 << " = "
396 << symbol_to_pointer(left) << end_of_stat() << endl;
397 oss << "__assume_aligned(" << symb1 << ", 16)" << end_of_stat() << endl;
398 }
399 if (!symb2_is_a_const) {
400 oss << "LIBINT2_REALTYPE* " << symb2 << " = "
401 << symbol_to_pointer(right) << end_of_stat() << endl;
402 oss << "__assume_aligned(" << symb2 << ", 16)" << end_of_stat() << endl;
403 }
404
405 oss << start_expr();
406 oss << symb0 << "[v]" << (accum ? " += " : " = ")
407 << (symb1_is_a_const ? left : symb1)
408 << (symb1_is_a_const ? " " : "[v] ")
409 << oper << " "
410 << (symb2_is_a_const ? right : symb2)
411 << (symb2_is_a_const ? "" : "[v]");
412 }
413 else {
414 oss << start_expr();
415 oss << name << (accum ? " += " : " = ") << left << " "
416 << oper << " " << right;
417 }
418 oss << end_of_stat() << endl;
419 oss << end_expr();
420
421 return oss.str();
422 }
423
424 std::string
assign_ternary_expr(const std::string & name,const std::string & arg1,const std::string & oper1,const std::string & arg2,const std::string & oper2,const std::string & arg3)425 CppCodeContext::assign_ternary_expr(const std::string& name,
426 const std::string& arg1,
427 const std::string& oper1,
428 const std::string& arg2,
429 const std::string& oper2,
430 const std::string& arg3) {
431 return assign_ternary_expr_(name, arg1, oper1, arg2, oper2, arg3, false);
432 }
433
434 std::string
assign_ternary_expr_(const std::string & name,const std::string & arg1,const std::string & oper1,const std::string & arg2,const std::string & oper2,const std::string & arg3,bool accum)435 CppCodeContext::assign_ternary_expr_(const std::string& name,
436 const std::string& arg1,
437 const std::string& oper1,
438 const std::string& arg2,
439 const std::string& oper2,
440 const std::string& arg3,
441 bool accum)
442 {
443 ostringstream oss;
444
445 // this should only be invoked for FMA, i.e. oper1 = "*" and oper2 = "+" or "-"
446 assert(oper1 == "*");
447 assert(oper2 == "+" || oper2 == "-");
448
449 if (vectorize_) {
450 std::string symb0 = unique_fp_name();
451 std::string symb1 = unique_fp_name();
452 std::string symb2 = unique_fp_name();
453 std::string symb3 = unique_fp_name();
454 std::string ptr0 = symbol_to_pointer(name);
455 std::string ptr1 = symbol_to_pointer(arg1);
456 std::string ptr2 = symbol_to_pointer(arg2);
457 std::string ptr3 = symbol_to_pointer(arg3);
458 bool symb1_is_a_const = (ptr1.length() == 0);
459 bool symb2_is_a_const = (ptr2.length() == 0);
460 bool symb3_is_a_const = (ptr3.length() == 0);
461 oss << "LIBINT2_REALTYPE* " << symb0 << " = "
462 << symbol_to_pointer(name) << end_of_stat() << endl;
463 oss << "__assume_aligned(" << symb0 << ", 16)" << end_of_stat() << endl;
464 if (!symb1_is_a_const) {
465 oss << "LIBINT2_REALTYPE* " << symb1 << " = "
466 << symbol_to_pointer(arg1) << end_of_stat() << endl;
467 oss << "__assume_aligned(" << symb1 << ", 16)" << end_of_stat() << endl;
468 }
469 if (!symb2_is_a_const) {
470 oss << "LIBINT2_REALTYPE* " << symb2 << " = "
471 << symbol_to_pointer(arg2) << end_of_stat() << endl;
472 oss << "__assume_aligned(" << symb2 << ", 16)" << end_of_stat() << endl;
473 }
474 if (!symb3_is_a_const) {
475 oss << "LIBINT2_REALTYPE* " << symb3 << " = "
476 << symbol_to_pointer(arg3) << end_of_stat() << endl;
477 oss << "__assume_aligned(" << symb3 << ", 16)" << end_of_stat() << endl;
478 }
479
480 oss << start_expr();
481 oss << symb0 << "[v]" << (accum ? " += " : " = ");
482 #if LIBINT_GENERATE_FMA
483 oss << "libint2::fma_" << (oper2 == "+" ? "plus" : "minus") << "("
484 << (symb1_is_a_const ? arg1 : symb1)
485 << (symb1_is_a_const ? " " : "[v] ")
486 << ","
487 << (symb2_is_a_const ? arg2 : symb2)
488 << (symb2_is_a_const ? "" : "[v]")
489 << ","
490 << (symb3_is_a_const ? arg3 : symb3)
491 << (symb3_is_a_const ? "" : "[v]")
492 << ")";
493 #else
494 oss << (symb1_is_a_const ? arg1 : symb1)
495 << (symb1_is_a_const ? " " : "[v] ")
496 << "* "
497 << (symb2_is_a_const ? arg2 : symb2)
498 << (symb2_is_a_const ? " " : "[v] ")
499 << oper2
500 << (symb3_is_a_const ? arg3 : symb3)
501 << (symb3_is_a_const ? "" : "[v]");
502 #endif
503 }
504 else {
505 oss << start_expr();
506 oss << name << (accum ? " += " : " = ");
507 #if LIBINT_GENERATE_FMA
508 oss << "libint2::fma_" << (oper2 == "+" ? "plus" : "minus") << "("
509 << arg1 << ", " << arg2 << ", " << arg3 << ")";
510 #else
511 oss << arg1 << " * " << arg2 << " " << oper2 << " " << arg3;
512 #endif
513 }
514 oss << end_of_stat() << endl;
515 oss << end_expr();
516
517 return oss.str();
518 }
519
520 std::string
accumulate_ternary_expr(const std::string & name,const std::string & arg1,const std::string & oper1,const std::string & arg2,const std::string & oper2,const std::string & arg3)521 CppCodeContext::accumulate_ternary_expr(const std::string& name,
522 const std::string& arg1,
523 const std::string& oper1,
524 const std::string& arg2,
525 const std::string& oper2,
526 const std::string& arg3) {
527 return assign_ternary_expr_(name, arg1, oper1, arg2, oper2, arg3, true);
528 }
529
530 std::string
symbol_to_pointer(const std::string & symbol)531 CppCodeContext::symbol_to_pointer(const std::string& symbol)
532 {
533 std::string::size_type loc = symbol.find("stack");
534 // if this quantity is on stack then the symbol is a scalar
535 if (loc != std::string::npos) {
536 ostringstream oss;
537 oss << "(&(" << symbol << "))";
538 return oss.str();
539 }
540
541 // if this quantity is a part of Libint_t then the symbol is a vector
542 // otherwise it's a constant
543 loc = symbol.find("inteval");
544 if (loc != std::string::npos)
545 return symbol;
546 else
547 return "";
548 }
549
550 std::string
start_expr() const551 CppCodeContext::start_expr() const
552 {
553 if (vectorize_)
554 return "#ifdef __INTEL_COMPILER\n#pragma ivdep\n#endif\nfor(int v=0; v<veclen; v++) {\n";
555 else
556 return "";
557 }
558
559
560 std::string
end_expr() const561 CppCodeContext::end_expr() const
562 {
563 if (vectorize_)
564 return "}\n";
565 else
566 return "";
567 }
568
569
570 std::string
stack_address(const DGVertex::Address & a) const571 CppCodeContext::stack_address(const DGVertex::Address& a) const
572 {
573 ostringstream oss;
574 if (vectorize_)
575 oss << "(" << a << ")*veclen";
576 else
577 oss << a;
578 return oss.str();
579 }
580
581 std::string
macro_define(const std::string & name) const582 CppCodeContext::macro_define(const std::string& name) const
583 {
584 ostringstream oss;
585 oss << "#define " << name << endl;
586 return oss.str();
587 }
588
589 std::string
macro_define(const std::string & name,const std::string & value) const590 CppCodeContext::macro_define(const std::string& name,
591 const std::string& value) const
592 {
593 ostringstream oss;
594 oss << "#define " << name << " " << value << endl;
595 return oss.str();
596 }
597
598 std::string
macro_if(const std::string & name) const599 CppCodeContext::macro_if(const std::string& name) const
600 {
601 ostringstream oss;
602 oss << "#if " << name << endl;
603 return oss.str();
604 }
605
606 std::string
macro_ifdef(const std::string & name) const607 CppCodeContext::macro_ifdef(const std::string& name) const
608 {
609 ostringstream oss;
610 oss << "#ifdef " << name << endl;
611 return oss.str();
612 }
613
614 std::string
macro_endif() const615 CppCodeContext::macro_endif() const
616 {
617 ostringstream oss;
618 oss << "#endif" << endl;
619 return oss.str();
620 }
621
622 std::string
comment(const std::string & statement) const623 CppCodeContext::comment(const std::string& statement) const
624 {
625 std::string result("/** ");
626 result += statement;
627 result += " */";
628 return result;
629 }
630
631 std::string
open_block() const632 CppCodeContext::open_block() const
633 {
634 return " {\n";
635 }
636
637 std::string
close_block() const638 CppCodeContext::close_block() const
639 {
640 return "}\n";
641 }
642
643 std::string
end_of_stat() const644 CppCodeContext::end_of_stat() const
645 {
646 static const std::string ends(";");
647 return ends;
648 }
649
650 std::string
value_to_pointer(const std::string & val) const651 CppCodeContext::value_to_pointer(const std::string& val) const
652 {
653 if (!vectorize_) {
654 std::string ptr("&(");
655 ptr += val; ptr += ")";
656 return ptr;
657 }
658 else {
659 return val;
660 }
661 }
662
663 SafePtr<ForLoop>
for_loop(std::string & varname,const SafePtr<Entity> & less_than,const SafePtr<Entity> & start_at) const664 CppCodeContext::for_loop(std::string& varname, const SafePtr<Entity>& less_than,
665 const SafePtr<Entity>& start_at) const
666 {
667 // not implemented
668 abort();
669 }
670
671 std::string
unique_fp_name() const672 CppCodeContext::unique_fp_name() const
673 {
674 char result[80];
675 sprintf(result,"fp%d", next_fp_index());
676 return result;
677 }
678
679 std::string
unique_int_name() const680 CppCodeContext::unique_int_name() const
681 {
682 char result[80];
683 sprintf(result,"i%d", next_int_index());
684 return result;
685 }
686
687 std::string
void_type() const688 CppCodeContext::void_type() const { return "void"; }
689 std::string
int_type() const690 CppCodeContext::int_type() const { return "int"; }
691 std::string
size_type() const692 CppCodeContext::size_type() const { return "size_t"; }
693 std::string
fp_type() const694 CppCodeContext::fp_type() const
695 {
696 if (!vectorize_)
697 return "LIBINT2_REALTYPE";
698 else
699 return ptr_fp_type();
700 }
701 std::string
ptr_fp_type() const702 CppCodeContext::ptr_fp_type() const { return "LIBINT2_REALTYPE*"; }
703 std::string
const_modifier() const704 CppCodeContext::const_modifier() const { return "const "; }
705 std::string
mutable_modifier() const706 CppCodeContext::mutable_modifier() const { return "#ifdef __cplusplus\nmutable \n#endif\n"; }
707
708 std::string
inteval_type_name(const std::string & tlabel) const709 CppCodeContext::inteval_type_name(const std::string& tlabel) const
710 {
711 if (cparams()->single_evaltype())
712 return inteval_gen_type_name();
713 else
714 return inteval_spec_type_name(tlabel);
715 }
716
717 std::string
inteval_spec_type_name(const std::string & tlabel) const718 CppCodeContext::inteval_spec_type_name(const std::string& tlabel) const
719 {
720 ostringstream oss;
721 oss << "Libint_" << tlabel << "_t";
722 return oss.str();
723 }
724
725 std::string
inteval_gen_type_name() const726 CppCodeContext::inteval_gen_type_name() const
727 {
728 return "Libint_t";
729 }
730