1 /*  Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
2 
3     Licensed under the Apache License, Version 2.0 (the "License");
4     you may not use this file except in compliance with the License.
5     You may obtain a copy of the License at
6 
7         http://www.apache.org/licenses/LICENSE-2.0
8 
9     Unless required by applicable law or agreed to in writing, software
10     distributed under the License is distributed on an "AS IS" BASIS,
11     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12     See the License for the specific language governing permissions and
13     limitations under the License.
14 
15  *
16  *  Author: Oliver J. Backhouse <olbackhouse@gmail.com>
17  *          Alejandro Santana-Bonilla <alejandro.santana_bonilla@kcl.ac.uk>
18  *          George H. Booth <george.booth@kcl.ac.uk>
19  */
20 
21 #include<stdlib.h>
22 #include<stdbool.h>
23 #include<assert.h>
24 #include<math.h>
25 #include<stdio.h>
26 
27 //#include "omp.h"
28 #include "config.h"
29 #include "vhf/fblas.h"
30 #include "ragf2.h"
31 
32 
33 
34 /*
35  *  Capital indices indicate the opposite spin to the lower case index
36  */
37 
38 /*
39  *  exact ERI
40  *  vv_xy = (xi|ja) [(yi|ja) + (yi|JA) - (yi|ja)]
41  *  vev_xy = (xi|ja) [(yi|ja) - (yj|ia)] (ei + ej - ea) + (xi|ja) (yi|JA) (ei + eJ - eA)
42  */
/*
 *  Adds the vv and energy-weighted vev contractions for the occupied
 *  indices i in [istart, iend) into vv and vev (each nmo*nmo).
 *
 *  Array extents, as implied by the slicing dimensions used below:
 *    xija : (nmo, noa, noa, nva)    e_i : noa    e_a : nva
 *    xiJA : (nmo, noa, nob, nvb)    e_I : nob    e_A : nvb
 *  ss_factor scales the same-spin term, os_factor the opposite-spin term.
 *  Each OpenMP thread accumulates privately and merges under a critical
 *  section, so vv/vev are added to rather than overwritten.
 */
void AGF2uee_vv_vev_islice(double *xija,
                           double *xiJA,
                           double *e_i,
                           double *e_I,
                           double *e_a,
                           double *e_A,
                           double os_factor,
                           double ss_factor,
                           int nmo,
                           int noa,
                           int nob,
                           int nva,
                           int nvb,
                           int istart,
                           int iend,
                           double *vv,
                           double *vev)
{
    const char tr = 'T';
    const char no_tr = 'N';
    const double d_one = 1.0;

    const int n_ss = noa * nva;     // same-spin (j, a) compound index
    const int n_os = nob * nvb;     // opposite-spin (J, A) compound index
    const int n_xj = nmo * noa;     // (x, j) compound index
    const int n_out = nmo * nmo;

#pragma omp parallel
{
    // thread-private work arrays
    double *den_ss = calloc(noa*nva, sizeof(double));
    double *den_os = calloc(nob*nvb, sizeof(double));
    double *x_jia = calloc(nmo*noa*nva, sizeof(double));
    double *x_ija = calloc(nmo*noa*nva, sizeof(double));
    double *x_iJA = calloc(nmo*nob*nvb, sizeof(double));
    double *ex_iJA = calloc(nmo*nob*nvb, sizeof(double));

    // thread-private accumulators, merged into vv/vev after the loop
    double *acc_vv = calloc(nmo*nmo, sizeof(double));
    double *acc_vev = calloc(nmo*nmo, sizeof(double));

#pragma omp for
    for (int i = istart; i < iend; i++) {
        const double ei = e_i[i];

        // den_ss[j,a] = ei + ej - ea,  den_os[J,A] = ei + eJ - eA
        AGF2sum_inplace_ener(ei, e_i, e_a, noa, nva, den_ss);
        AGF2sum_inplace_ener(ei, e_I, e_A, nob, nvb, den_os);

        // x_ija[x,(j,a)] = (xi|ja),  x_iJA[x,(J,A)] = (xi|JA)
        AGF2slice_0i2(xija, nmo, noa, n_ss, i, x_ija);
        AGF2slice_0i2(xiJA, nmo, noa, n_os, i, x_iJA);

        // x_jia[(x,j),a] = (xj|ia): fix the third index of xija at i
        AGF2slice_0i2(xija, n_xj, noa, nva, i, x_jia);

        // x_jia <- ss_factor * (x_ija - x_jia), the antisymmetrised integral
        AGF2sum_inplace(x_ija, x_jia, nmo*n_ss, ss_factor, -ss_factor);

        // acc_vv += same-spin + opposite-spin contractions
        dgemm_(&tr, &no_tr, &nmo, &nmo, &n_ss, &d_one, x_jia, &n_ss, x_ija, &n_ss, &d_one, acc_vv, &nmo);
        dgemm_(&tr, &no_tr, &nmo, &nmo, &n_os, &os_factor, x_iJA, &n_os, x_iJA, &n_os, &d_one, acc_vv, &nmo);

        // weight by the denominators: x_ija in place, x_iJA into ex_iJA
        AGF2prod_inplace_ener(den_ss, x_ija, nmo, n_ss);
        AGF2prod_outplace_ener(den_os, x_iJA, nmo, n_os, ex_iJA);

        // acc_vev += energy-weighted same-spin + opposite-spin contractions
        dgemm_(&tr, &no_tr, &nmo, &nmo, &n_ss, &d_one, x_jia, &n_ss, x_ija, &n_ss, &d_one, acc_vev, &nmo);
        dgemm_(&tr, &no_tr, &nmo, &nmo, &n_os, &os_factor, x_iJA, &n_os, ex_iJA, &n_os, &d_one, acc_vev, &nmo);
    }

    free(den_ss);
    free(den_os);
    free(x_jia);
    free(x_ija);
    free(x_iJA);
    free(ex_iJA);

    // serialise the merge of the per-thread partial sums
#pragma omp critical
    for (int p = 0; p < n_out; p++) {
        vv[p] += acc_vv[p];
        vev[p] += acc_vev[p];
    }

    free(acc_vv);
    free(acc_vev);
}
}
139 
140 
141 /*
142  *  density fitting
143  *  (xi|ja) = (xi|Q)(Q|ja)
144  *  vv_xy = (xi|ja) [(yi|ja) + (yi|JA) - (yi|ja)]
145  *  vev_xy = (xi|ja) [(yi|ja) - (yj|ia)] (ei + ej - ea) + (xi|ja) (yi|JA) (ei + eJ - eA)
146  */
/*
 *  Density-fitted counterpart of the exact-ERI kernel: rebuilds the
 *  required (xi|ja), (xi|JA) and (xj|ia) blocks from the three-centre
 *  integrals for each occupied i in [istart, iend), then adds the vv and
 *  energy-weighted vev contractions into vv and vev (each nmo*nmo).
 *
 *  Array extents, as implied by the slicing/DGEMM dimensions used below:
 *    qxi : (naux, nmo, noa)    e_i : noa    e_a : nva
 *    qja : (naux, noa, nva)    e_I : nob    e_A : nvb
 *    qJA : (naux, nob, nvb)
 *  Each OpenMP thread accumulates privately and merges under a critical
 *  section, so vv/vev are added to rather than overwritten.
 */
void AGF2udf_vv_vev_islice(double *qxi,
                           double *qja,
                           double *qJA,
                           double *e_i,
                           double *e_I,
                           double *e_a,
                           double *e_A,
                           double os_factor,
                           double ss_factor,
                           int nmo,
                           int noa,
                           int nob,
                           int nva,
                           int nvb,
                           int naux,
                           int istart,
                           int iend,
                           double *vv,
                           double *vev)
{
    const char tr = 'T';
    const char no_tr = 'N';
    const double d_zero = 0.0;
    const double d_one = 1.0;

    const int n_ss = noa * nva;     // same-spin (j, a) compound index
    const int n_os = nob * nvb;     // opposite-spin (J, A) compound index
    const int n_xj = nmo * noa;     // (x, j) compound index
    const int n_out = nmo * nmo;

#pragma omp parallel
{
    // thread-private auxiliary-basis slices
    double *q_x = calloc(naux*nmo, sizeof(double));
    double *q_a = calloc(naux*nva, sizeof(double));

    // thread-private denominators and rebuilt integral blocks
    double *den_ss = calloc(noa*nva, sizeof(double));
    double *den_os = calloc(nob*nvb, sizeof(double));
    double *x_jia = calloc(nmo*noa*nva, sizeof(double));
    double *x_ija = calloc(nmo*noa*nva, sizeof(double));
    double *x_iJA = calloc(nmo*nob*nvb, sizeof(double));
    double *ex_iJA = calloc(nmo*nob*nvb, sizeof(double));

    // thread-private accumulators, merged into vv/vev after the loop
    double *acc_vv = calloc(nmo*nmo, sizeof(double));
    double *acc_vev = calloc(nmo*nmo, sizeof(double));

#pragma omp for
    for (int i = istart; i < iend; i++) {
        const double ei = e_i[i];

        // den_ss[j,a] = ei + ej - ea,  den_os[J,A] = ei + eJ - eA
        AGF2sum_inplace_ener(ei, e_i, e_a, noa, nva, den_ss);
        AGF2sum_inplace_ener(ei, e_I, e_A, nob, nvb, den_os);

        // q_x[Q,x] = (Q|xi),  q_a[Q,a] = (Q|ia)
        AGF2slice_01i(qxi, naux, nmo, noa, i, q_x);
        AGF2slice_0i2(qja, naux, noa, nva, i, q_a);

        // contract over Q: x_ija = (xi|ja), x_iJA = (xi|JA), x_jia = (xj|ia)
        dgemm_(&no_tr, &tr, &n_ss, &nmo, &naux, &d_one, qja, &n_ss, q_x, &nmo, &d_zero, x_ija, &n_ss);
        dgemm_(&no_tr, &tr, &n_os, &nmo, &naux, &d_one, qJA, &n_os, q_x, &nmo, &d_zero, x_iJA, &n_os);
        dgemm_(&no_tr, &tr, &nva, &n_xj, &naux, &d_one, q_a, &nva, qxi, &n_xj, &d_zero, x_jia, &nva);

        // x_jia <- ss_factor * (x_ija - x_jia), the antisymmetrised integral
        AGF2sum_inplace(x_ija, x_jia, nmo*n_ss, ss_factor, -ss_factor);

        // acc_vv += same-spin + opposite-spin contractions
        dgemm_(&tr, &no_tr, &nmo, &nmo, &n_ss, &d_one, x_jia, &n_ss, x_ija, &n_ss, &d_one, acc_vv, &nmo);
        dgemm_(&tr, &no_tr, &nmo, &nmo, &n_os, &os_factor, x_iJA, &n_os, x_iJA, &n_os, &d_one, acc_vv, &nmo);

        // weight by the denominators: x_ija in place, x_iJA into ex_iJA
        AGF2prod_inplace_ener(den_ss, x_ija, nmo, n_ss);
        AGF2prod_outplace_ener(den_os, x_iJA, nmo, n_os, ex_iJA);

        // acc_vev += energy-weighted same-spin + opposite-spin contractions
        dgemm_(&tr, &no_tr, &nmo, &nmo, &n_ss, &d_one, x_jia, &n_ss, x_ija, &n_ss, &d_one, acc_vev, &nmo);
        dgemm_(&tr, &no_tr, &nmo, &nmo, &n_os, &os_factor, x_iJA, &n_os, ex_iJA, &n_os, &d_one, acc_vev, &nmo);
    }

    free(q_a);
    free(q_x);
    free(den_ss);
    free(den_os);
    free(x_jia);
    free(x_ija);
    free(x_iJA);
    free(ex_iJA);

    // serialise the merge of the per-thread partial sums
#pragma omp critical
    for (int p = 0; p < n_out; p++) {
        vv[p] += acc_vv[p];
        vev[p] += acc_vev[p];
    }

    free(acc_vv);
    free(acc_vev);
}
}
256 
257 
258 /*
259  *  Removes an index from DGEMM and into a for loop to reduce the
260  *  thread-private memory overhead, at the cost of serial speed
261  */
/*
 *  Array arguments are as for AGF2udf_vv_vev_islice:
 *    qxi : (naux, nmo, noa)    qja : (naux, noa, nva)    qJA : (naux, nob, nvb)
 *  The parallel loop runs over a flattened pair index ij in [start, end),
 *  decoded as i = ij / max(noa, nob), j = ij % max(noa, nob).  The
 *  same-spin term is formed only when j < noa and the opposite-spin term
 *  only when j < nob.  Results are added into vv and vev (each nmo*nmo).
 */
void AGF2udf_vv_vev_islice_lowmem(double *qxi,
                                  double *qja,
                                  double *qJA,
                                  double *e_i,
                                  double *e_I,
                                  double *e_a,
                                  double *e_A,
                                  double os_factor,
                                  double ss_factor,
                                  int nmo,
                                  int noa,
                                  int nob,
                                  int nva,
                                  int nvb,
                                  int naux,
                                  int start,
                                  int end,
                                  double *vv,
                                  double *vev)
{
    const double D0 = 0.0;
    const double D1 = 1.0;
    const char TRANS_T = 'T';
    const char TRANS_N = 'N';
    const int one = 1;

    // j must span both spin channels, so the pair index is decoded with
    // the larger of the two occupied dimensions
    const int nocc = (noa > nob) ? noa : nob;

#pragma omp parallel
{
    double *qx_i = calloc(naux*nmo, sizeof(double));
    double *qx_j = calloc(naux*nmo, sizeof(double));
    double *qa_i = calloc(naux*nva, sizeof(double));
    double *qa_j = calloc(naux*nva, sizeof(double));
    double *qA_j = calloc(naux*nvb, sizeof(double));
    double *xa_i = calloc(nmo*nva, sizeof(double));
    double *xa_j = calloc(nmo*nva, sizeof(double));
    double *xA_i = calloc(nmo*nvb, sizeof(double));
    double *ea = calloc(nva, sizeof(double));
    double *eA = calloc(nvb, sizeof(double));
    double *exA_i = calloc(nmo*nvb, sizeof(double));

    double *vv_priv = calloc(nmo*nmo, sizeof(double));
    double *vev_priv = calloc(nmo*nmo, sizeof(double));

    bool do_os, do_ss;
    int i, j, ij;

#pragma omp for
    for (ij = start; ij < end; ij++) {
        // i = 0 -> noa
        // j = 0 -> max(noa, nob)
        i = ij / nocc;
        j = ij % nocc;

        do_os = j < nob;
        do_ss = j < noa;

        // build qx_i = (Q|xi), needed by both spin channels
        AGF2slice_01i(qxi, naux, nmo, noa, i, qx_i);

        if (do_ss) {
            // The j-slices of qxi and qja are taken only here: their
            // occupied axis has length noa, and when nob > noa the index j
            // can reach past it, which would read beyond the arrays.
            AGF2slice_01i(qxi, naux, nmo, noa, j, qx_j);
            AGF2slice_0i2(qja, naux, noa, nva, i, qa_i);
            AGF2slice_0i2(qja, naux, noa, nva, j, qa_j);

            // xa_i[x,a] = (xj|ia)
            dgemm_(&TRANS_N, &TRANS_T, &nva, &nmo, &naux, &D1, qa_i, &nva, qx_j, &nmo, &D0, xa_i, &nva);

            // xa_j[x,a] = (xi|ja)
            dgemm_(&TRANS_N, &TRANS_T, &nva, &nmo, &naux, &D1, qa_j, &nva, qx_i, &nmo, &D0, xa_j, &nva);

            // ea[a] = ei + ej - ea (symmetric in i and j)
            AGF2sum_inplace_ener(e_i[i], &(e_i[j]), e_a, one, nva, ea);

            // xa_j <- ss_factor * (xa_i - xa_j)
            // (i and j play swapped roles relative to AGF2udf_vv_vev_islice;
            // the sum over all (i, j) pairs is symmetric under i <-> j)
            AGF2sum_inplace(xa_i, xa_j, nmo*nva, ss_factor, -ss_factor);

            // vv_xy += same-spin contraction
            dgemm_(&TRANS_T, &TRANS_N, &nmo, &nmo, &nva, &D1, xa_j, &nva, xa_i, &nva, &D1, vv_priv, &nmo);

            // inplace xa_i = ea * xa_i
            AGF2prod_inplace_ener(ea, xa_i, nmo, nva);

            // vev_xy += energy-weighted same-spin contraction
            dgemm_(&TRANS_T, &TRANS_N, &nmo, &nmo, &nva, &D1, xa_j, &nva, xa_i, &nva, &D1, vev_priv, &nmo);
        }

        if (do_os) {
            // build qA_j = (Q|JA) for J = j
            AGF2slice_0i2(qJA, naux, nob, nvb, j, qA_j);

            // xA_i[x,A] = (xi|JA)
            dgemm_(&TRANS_N, &TRANS_T, &nvb, &nmo, &naux, &D1, qA_j, &nvb, qx_i, &nmo, &D0, xA_i, &nvb);

            // eA[A] = ei + eJ - eA
            AGF2sum_inplace_ener(e_i[i], &(e_I[j]), e_A, one, nvb, eA);

            // vv_xy += xiJA * yiJA
            dgemm_(&TRANS_T, &TRANS_N, &nmo, &nmo, &nvb, &os_factor, xA_i, &nvb, xA_i, &nvb, &D1, vv_priv, &nmo);

            // outplace exA_i = eA * xA_i
            AGF2prod_outplace_ener(eA, xA_i, nmo, nvb, exA_i);

            // vev_xy += xiJA * eiJA * yiJA
            dgemm_(&TRANS_T, &TRANS_N, &nmo, &nmo, &nvb, &os_factor, xA_i, &nvb, exA_i, &nvb, &D1, vev_priv, &nmo);
        }
    }

    free(qx_i);
    free(qx_j);
    free(qa_i);
    free(qa_j);
    free(qA_j);
    free(xa_i);
    free(xa_j);
    free(xA_i);
    free(ea);
    free(eA);
    free(exA_i);

    // serialise the merge of the per-thread partial sums
#pragma omp critical
    for (i = 0; i < (nmo*nmo); i++) {
        vv[i] += vv_priv[i];
        vev[i] += vev_priv[i];
    }

    free(vv_priv);
    free(vev_priv);
}
}
400