1 /*
2 * This source code is part of
3 *
4 * E R K A L E
5 * -
6 * DFT from Hel
7 *
8 * Written by Susi Lehtola, 2010-2013
9 * Copyright (c) 2010-2013, Susi Lehtola
10 *
11 * This program is free software; you can redistribute it and/or
12 * modify it under the terms of the GNU General Public License
13 * as published by the Free Software Foundation; either version 2
14 * of the License, or (at your option) any later version.
15 */
16
17 #include "solidharmonics.h"
18 #include "eriworker.h"
19
main(void)20 int main(void) {
21 enum {i, j, k, l} trans;
22
23 printf("#include \"eriworker.h\"\n");
24 printf("#include \"solidharmonics.h\"\n");
25 printf("#include \"mathf.h\"\n\n");
26
27 for(int it=0;it<4;it++) {
28 switch(it) {
29 case(0):
30 trans=i;
31 break;
32 case(1):
33 trans=j;
34 break;
35 case(2):
36 trans=k;
37 break;
38 case(3):
39 trans=l;
40 break;
41 }
42
43 std::string arg1, arg2, arg3;
44
45 switch(trans) {
46 case(i):
47 arg1="Nj";
48 arg2="Nk";
49 arg3="Nl";
50 break;
51 case(j):
52 arg1="Ni";
53 arg2="Nk";
54 arg3="Nl";
55 break;
56 case(k):
57 arg1="Ni";
58 arg2="Nj";
59 arg3="Nl";
60 break;
61 case(l):
62 arg1="Ni";
63 arg2="Nj";
64 arg3="Nk";
65 break;
66 }
67
68 const char ijkl[]="ijkl";
69
70 // Individual transforms
71 for(int am=0;am<LIBINT_MAX_AM;am++) {
72 // Get transformation matrix
73 arma::mat transmat=Ylm_transmat(am);
74 size_t Nsph=transmat.n_rows;
75 size_t Ncart=transmat.n_cols;
76
77 // Print function header
78 printf("static void transform_%c%i(size_t %s, size_t %s, size_t %s, const std::vector<double> *input, std::vector<double> *output) {\n",ijkl[it],am,arg1.c_str(),arg2.c_str(),arg3.c_str());
79 printf(" (*output).clear();\n");
80 printf(" (*output).resize(%i*%s*%s*%s,0.0);\n",(int) Nsph,arg1.c_str(),arg2.c_str(),arg3.c_str());
81
82 // Transform loops
83 switch(trans) {
84 case(i):
85 printf(" for(size_t jj=0;jj<Nj;jj++)\n");
86 printf(" for(size_t kk=0;kk<Nk;kk++)\n");
87 printf(" for(size_t ll=0;ll<Nl;ll++) {\n");
88 for(size_t iin=0;iin<Ncart;iin++)
89 for(size_t iout=0;iout<Nsph;iout++)
90 if(transmat(iout,iin)!=0.0)
91 printf(" (*output)[((%2i*Nj+jj)*Nk+kk)*Nl+ll] += % .16e * (*input)[((%2i*Nj+jj)*Nk+kk)*Nl+ll];\n",(int) iout,transmat(iout,iin),(int) iin);
92 break;
93
94 case(j):
95 printf(" for(size_t ii=0;ii<Ni;ii++)\n");
96 printf(" for(size_t kk=0;kk<Nk;kk++)\n");
97 printf(" for(size_t ll=0;ll<Nl;ll++) {\n");
98 for(size_t jin=0;jin<Ncart;jin++)
99 for(size_t jout=0;jout<Nsph;jout++)
100 if(transmat(jout,jin)!=0.0)
101 printf(" (*output)[((ii*%2i+%2i)*Nk+kk)*Nl+ll] += % .16e * (*input)[((ii*%2i+%2i)*Nk+kk)*Nl+ll];\n",(int) Nsph,(int) jout,transmat(jout,jin),(int) Ncart,(int) jin);
102 break;
103
104 case(k):
105 printf(" for(size_t ii=0;ii<Ni;ii++)\n");
106 printf(" for(size_t jj=0;jj<Nj;jj++)\n");
107 printf(" for(size_t ll=0;ll<Nl;ll++) {\n");
108 for(size_t kin=0;kin<Ncart;kin++)
109 for(size_t kout=0;kout<Nsph;kout++)
110 if(transmat(kout,kin)!=0.0)
111 printf(" (*output)[((ii*Nj+jj)*%2i+%2i)*Nl+ll] += % .16e * (*input)[((ii*Nj+jj)*%2i+%2i)*Nl+ll];\n",(int) Nsph,(int) kout,transmat(kout,kin),(int) Ncart,(int) kin);
112 break;
113
114 case(l):
115 printf(" for(size_t ii=0;ii<Ni;ii++)\n");
116 printf(" for(size_t jj=0;jj<Nj;jj++)\n");
117 printf(" for(size_t kk=0;kk<Nk;kk++) {\n");
118 for(size_t lin=0;lin<Ncart;lin++)
119 for(size_t lout=0;lout<Nsph;lout++)
120 if(transmat(lout,lin)!=0.0)
121 printf(" (*output)[((ii*Nj+jj)*Nk+kk)*%2i+%2i] += % .16e * (*input)[((ii*Nj+jj)*Nk+kk)*%2i+%2i];\n",(int) Nsph,(int) lout,transmat(lout,lin),(int) Ncart,(int) lin);
122 break;
123 }
124
125 printf(" }\n");
126 printf("}\n\n");
127 }
128
129 /// Main driver
130 printf("void IntegralWorker::transform_%c(int am, size_t %s, size_t %s, size_t %s) {\n",ijkl[it],arg1.c_str(),arg2.c_str(),arg3.c_str());
131 // Table of drivers
132 printf(" static void (*f[%i])(size_t, size_t, size_t, const std::vector<double> *, std::vector<double> *)={\n",LIBINT_MAX_AM);
133 for(int am=0;am<LIBINT_MAX_AM;am++) {
134 printf(" transform_%c%i",ijkl[it],am);
135 if(am<LIBINT_MAX_AM-1)
136 printf(",");
137 printf("\n");
138 }
139 printf(" };\n");
140
141 // Call driver
142 printf(" f[am](%s,%s,%s,input,output);\n",arg1.c_str(),arg2.c_str(),arg3.c_str());
143
144 // Swap arrays
145 printf(" std::swap(input,output);\n");
146 printf("}\n");
147
148 }
149 return 0;
150 }
151